00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef C_BATCHLEARNING__
00034 #define C_BATCHLEARNING__
00035
00036
00037 #include "cagentcontroller.h"
00038 #include "cparameters.h"
00039 #include "cagent.h"
00040
00041 #include "newmat/newmat.h"
00042
// Forward declarations of project types. The declarations further down only
// mention these types through pointers, so their full definitions are not
// needed in this header.
// NOTE(review): CContinuousActionQFunction and CStateProperties are each
// declared twice in this list; the repeats are redundant but harmless.
// NOTE(review): several names used below (e.g. CDeterministicController,
// CParameterObject, CSemiMDPSender, CStateModifier, CActionDataSet,
// ColumnVector) are not declared here; presumably they come from the headers
// included above -- confirm.
00043 class CRewardHistory;
00044 class CEpisodeHistory;
00045 class CRewardLogger;
00046 class CAgentLogger;
00047
00048 class CContinuousActionQFunction;
00049 class CPolicySameStateEvaluator;
00050 class CGradientUpdateFunction;
00051 class CSemiMDPListener;
00052 class CLSTDLambda;
00053 class CSemiMDPRewardListener;
00054
00055 class CFeatureCalculator;
00056 class CFeatureVFunction;
00057 class CDataSet;
00058 class CDataSet1D;
00059 class CDataSubset;
00060
00061 class CAbstractVFunction;
00062 class CGradientVFunction;
00063 class CQFunction;
00064
00065 class CSupervisedLearner;
00066 class CSupervisedWeightedLearner;
00067 class CSupervisedQFunctionLearner;
00068 class CSupervisedQFunctionWeightedLearner;
00069
00070 class CStochasticPolicy;
00071
00072 class CLearnDataObject;
00073
// Duplicate of the declaration earlier in this list.
00074 class CContinuousActionQFunction;
00075 class CStateProperties;
00076
00077 class CKDTree;
00078 class CKNearestNeighbors;
00079 class CDataPreprocessor;
00080
00081 class CSamplingBasedTransitionModel;
00082 class CSamplingBasedGraph;
00083 class CActionDistribution;
00084
00085 class CAction;
00086 class CActionSet;
00087
00088 class CState;
00089 class CStateCollection;
// Duplicate of the CStateProperties declaration a few lines above.
00090 class CStateProperties;
00091
00092 class CStep;
00093 class CNewFeatureCalculator;
00094 class CAbstractQFunction;
00095
00096
// Controller derived from CDeterministicController that holds a CAction
// pointer (nextAction) and a CActionDataSet pointer (actionDataSet).
// Presumably setAction() records the action that getNextAction() later
// returns -- confirm against the implementation file.
// NOTE(review): the destructor is not spelled `virtual`, unlike most classes
// in this header; it is virtual only if the base class destructor is --
// confirm in cagentcontroller.h.
00097 class CBatchLearningPolicy : public CDeterministicController
00098 {
00099 protected:
00100 CActionDataSet *actionDataSet;
00101 CAction *nextAction;
00102 public:
00103 CBatchLearningPolicy(CActionSet *actions);
00104 ~CBatchLearningPolicy();
00105
// actionData is optional and defaults to NULL. The default argument is bound
// to the static type used at the call site, not to the dynamic type.
00106 virtual CAction *getNextAction(CStateCollection *state, CActionDataSet *actionData = NULL);
00107
00108 virtual void setAction(CAction *action, CActionData *data);
00109 };
00110
00111
// Abstract base class for policy-evaluation algorithms; derives from
// CParameterObject. Concrete subclasses must implement evaluatePolicy(int).
00112 class CPolicyEvaluation : public CParameterObject
00113 {
00114 protected:
00115
00116 public:
// maxEvaluations defaults to 100. Presumably it is the count that the
// no-argument evaluatePolicy() passes on -- confirm in the implementation.
00117 CPolicyEvaluation(int maxEvaluations = 100);
00118 virtual ~CPolicyEvaluation();
00119
// Overload without an explicit evaluation count; defined out of line.
00120 virtual void evaluatePolicy();
// Pure virtual: every concrete evaluator provides this overload.
00121 virtual void evaluatePolicy(int numEvaluations) = 0;
00122
// Hook with an empty default body; subclasses override it when they have
// learned data to reset.
00123 virtual void resetLearnData() {};
00124 };
00125
// Policy evaluation that works on a CGradientUpdateFunction (learnData) and
// keeps a raw double buffer (oldWeights). The class stays abstract because
// evaluatePolicy(int) is re-declared pure virtual.
// Presumably oldWeights holds a snapshot of the weights used to measure how
// much they changed between evaluations -- confirm in the implementation.
00126 class CPolicyEvaluationGradientFunction : public CPolicyEvaluation
00127 {
00128 protected:
00129 CGradientUpdateFunction *learnData;
00130 double *oldWeights;
00131
00132
// Overridable; the LSTD subclasses in this header provide their own version.
// NOTE(review): the parameter has the same name as the oldWeights member.
00133 virtual double getWeightDifference(double *oldWeights);
00134 public:
// Public flag; presumably controls whether resetLearnData() actually resets
// learnData -- confirm.
00135 bool resetData;
00136
// Defaults: treshold = 0.1, maxEvaluations = 100. ("treshold" is the spelling
// used by the declaration.)
00137 CPolicyEvaluationGradientFunction(CGradientUpdateFunction *learnData, double treshold = 0.1, int maxEvaluations = 100);
00138 virtual ~CPolicyEvaluationGradientFunction();
00139
00140
00141 virtual void evaluatePolicy(int numEvaluations) = 0;
00142
00143 virtual void resetLearnData();
00144 };
00145
// Concrete gradient-function policy evaluation that holds a CAgent, a
// CSemiMDPListener (learner) and a CSemiMDPSender.
// Presumably the agent is run and the learner is attached to the sender as a
// listener during evaluation -- confirm in the implementation.
// NOTE(review): only evaluatePolicy(int) is declared here, so the no-argument
// overload inherited from CPolicyEvaluation is hidden when calling through
// this type (it remains reachable through a base-class pointer).
00146 class COnlinePolicyEvaluation : public CPolicyEvaluationGradientFunction
00147 {
00148 protected:
00149 CAgent *agent;
00150 CSemiMDPListener *learner;
00151
00152 CSemiMDPSender *semiMDPSender;
00153 public:
00154
// All three int parameters are required (no defaults).
00155 COnlinePolicyEvaluation(CAgent *agent, CSemiMDPListener *learner, CGradientUpdateFunction *learnData, int maxEvaluationEpisodes, int numSteps, int checkWeightsPerEpisode);
00156 virtual ~COnlinePolicyEvaluation();
00157
00158 virtual void evaluatePolicy(int numEvaluations);
00159
// Non-virtual setter for the semiMDPSender member.
00160 void setSemiMDPSender(CSemiMDPSender *semiMDPSender);
00161 };
00162
// Online policy evaluation specialised for a CLSTDLambda learner. It keeps a
// typed pointer to that learner and overrides getWeightDifference() and
// resetLearnData(); evaluatePolicy(int) is inherited from
// COnlinePolicyEvaluation.
00163 class CLSTDOnlinePolicyEvaluation : public COnlinePolicyEvaluation
00164 {
00165 protected:
00166 CLSTDLambda *lstdLearner;
00167
00168 virtual double getWeightDifference(double *oldWeights);
00169 public:
00170
// Takes fewer parameters than the base-class constructor; how they map onto
// the base constructor's arguments is decided in the implementation file.
00171 CLSTDOnlinePolicyEvaluation(CAgent *agent, CLSTDLambda *learner, CGradientUpdateFunction *learnData, int maxEvaluationSteps, int nSteps);
00172 virtual ~CLSTDOnlinePolicyEvaluation();
00173
00174 virtual void resetLearnData();
00175
00176 };
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
// Policy evaluation over stored episodes. Inherits from both
// CPolicyEvaluationGradientFunction and CSemiMDPSender and holds a
// CSemiMDPRewardListener (learner), a CEpisodeHistory, a CRewardHistory, a
// list of state modifiers and an optional CBatchLearningPolicy.
// Presumably the stored episodes are replayed to the learner -- confirm in
// the implementation.
00209 class COfflineEpisodePolicyEvaluation : public CPolicyEvaluationGradientFunction, public CSemiMDPSender
00210 {
00211 protected:
00212 CSemiMDPRewardListener *learner;
00213 CEpisodeHistory *episodeHistory;
00214 CRewardHistory *rewardLogger;
00215
00216 std::list<CStateModifier *> *modifiers;
00217
00218 CBatchLearningPolicy *policy;
00219 public:
00220
00221
// Two constructors: the second additionally takes a CRewardHistory. What the
// first one uses for rewardLogger is not visible here -- confirm.
00222 COfflineEpisodePolicyEvaluation(CEpisodeHistory *episodeHistory, CSemiMDPRewardListener *learner, CGradientUpdateFunction *learnData, std::list<CStateModifier *> *l_modifiers, int maxEvaluationEpisodes);
00223 COfflineEpisodePolicyEvaluation(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSemiMDPRewardListener *learner, CGradientUpdateFunction *learnData, std::list<CStateModifier *> *l_modifiers, int maxEvaluationEpisodes);
00224 virtual ~COfflineEpisodePolicyEvaluation();
00225
// Non-virtual setter; presumably assigns the policy member.
00226 void setBatchLearningPolicy(CBatchLearningPolicy *l_policy);
00227
00228 virtual void evaluatePolicy(int numEvaluations);
00229 };
00230
00231
// Offline episode policy evaluation specialised for a CLSTDLambda learner.
// It keeps a typed pointer to that learner and overrides
// getWeightDifference() and resetLearnData(); evaluatePolicy(int) is
// inherited from COfflineEpisodePolicyEvaluation.
00232 class CLSTDOfflineEpisodePolicyEvaluation : public COfflineEpisodePolicyEvaluation
00233 {
00234 protected:
00235 CLSTDLambda *lstdLearner;
00236
00237 virtual double getWeightDifference(double *oldWeights);
00238 public:
00239
// learnData is typed CGradientVFunction* here, whereas the base-class
// constructors take CGradientUpdateFunction*.
00240 CLSTDOfflineEpisodePolicyEvaluation(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CLSTDLambda *learner, CGradientVFunction *learnData, std::list<CStateModifier *> *l_modifiers, int episodes);
00241 virtual ~CLSTDOfflineEpisodePolicyEvaluation();
00242
00243 virtual void resetLearnData();
00244 };
00245
00246
// Abstract interface for objects that gather data; derives from
// CParameterObject. Subclasses implement collectData().
00247 class CDataCollector : public CParameterObject
00248 {
00249 protected:
00250
00251 public:
00252 CDataCollector();
00253 virtual ~CDataCollector();
00254
// Pure virtual: what is collected and where it is stored is up to the
// subclass.
00255 virtual void collectData() = 0;
00256
00257 };
00258
// Forward declaration needed for the list member and addUnknownDataFunction()
// below.
00259 class CUnknownDataQFunction;
00260
// Data collector that holds an agent, an agent logger and a reward logger.
// Presumably collectData() runs the agent so that the loggers record new
// episodes -- confirm in the implementation.
00261 class CDataCollectorFromAgentLogger : public CDataCollector
00262 {
00263 protected:
00264 CAgentLogger *logger;
00265 CRewardLogger *rewardLogger;
00266 CAgent *agent;
00267
00268 int numCollections;
00269 CSemiMarkovDecisionProcess *sender;
00270
00271 std::list<CUnknownDataQFunction *> *unknownDataQFunctions;
00272 CAgentController *controller;
00273 public:
// numEpisodes and numSteps are required; how they are stored is not visible
// in this header.
00274 CDataCollectorFromAgentLogger(CAgent *agent, CAgentLogger *logger, CRewardLogger *rewardLogger, int numEpisodes, int numSteps);
00275 virtual ~CDataCollectorFromAgentLogger();
00276
00277 virtual void collectData();
00278
00279 virtual void setController(CAgentController *controller);
// NOTE(review): the parameter is a single function pointer although its name
// matches the list member unknownDataQFunctions; presumably it is appended to
// that list -- confirm.
00280 virtual void addUnknownDataFunction(CUnknownDataQFunction *unknownDataQFunctions);
00281
// Non-virtual setter for the sender member.
00282 void setSemiMDPSender(CSemiMarkovDecisionProcess *sender);
00283 };
00284
00285
00286
// Policy-iteration driver that combines two CLearnDataObject instances
// (policyFunction and evaluationFunction), a CPolicyEvaluation and an
// optional CDataCollector.
// NOTE(review): this class inherits CParameterObject with `virtual public`,
// while the other classes in this header use plain `public`; with a virtual
// base, the most-derived class is responsible for constructing
// CParameterObject -- confirm this is intended.
00287 class CPolicyIteration : virtual public CParameterObject
00288 {
00289 protected:
00290 CLearnDataObject *policyFunction;
00291 CLearnDataObject *evaluationFunction;
00292
00293 CPolicyEvaluation *evaluation;
00294 CDataCollector *collector;
00295 public:
00296
// collector is optional and defaults to NULL.
00297 CPolicyIteration(CLearnDataObject *policyFunction, CLearnDataObject *evaluationFunction, CPolicyEvaluation *evaluation, CDataCollector *collector = NULL);
00298 virtual ~CPolicyIteration();
00299
// Both steps are virtual; CPolicyIterationNewFeatures overrides them.
00300 virtual void doPolicyIterationStep();
00301 virtual void initPolicyIteration();
00302 };
00303
00304
00305
// Abstract interface with no base class and no data members. Implementers
// provide calculateNewFeatures() and swapValueFunctions(); initFeatures() and
// resetData() have inline default bodies.
00306 class CNewFeatureCalculatorDataGenerator
00307 {
00308 protected:
00309
00310 public:
00311 virtual ~CNewFeatureCalculatorDataGenerator() {};
00312
// Default implementation simply forwards to calculateNewFeatures().
00313 virtual void initFeatures() {calculateNewFeatures();};
00314 virtual void calculateNewFeatures() = 0;
00315
00316 virtual void swapValueFunctions() = 0;
00317
// Default implementation does nothing.
00318 virtual void resetData() {};
00319 };
00320
// CPolicyIteration variant that additionally holds a
// CNewFeatureCalculatorDataGenerator and overrides both iteration steps.
// Presumably the overrides call into newFeatureCalculator around the base
// class behaviour -- confirm in the implementation.
00321 class CPolicyIterationNewFeatures : public CPolicyIteration
00322 {
00323 protected:
00324 CNewFeatureCalculatorDataGenerator *newFeatureCalculator;
00325 public:
00326
// Same arguments as CPolicyIteration plus newFeatureCalculator, inserted
// before the optional collector (default NULL).
00327 CPolicyIterationNewFeatures(CLearnDataObject *policyFunction, CLearnDataObject *evaluationFunction, CPolicyEvaluation *evaluation, CNewFeatureCalculatorDataGenerator *newFeatureCalculator, CDataCollector *collector = NULL);
00328 virtual ~CPolicyIterationNewFeatures();
00329
00330 virtual void doPolicyIterationStep();
00331 virtual void initPolicyIteration();
00332 };
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
// Abstract interface for objects that accumulate (state, action, output,
// weighting) samples and train a function approximator from them. Derives
// from CParameterObject; all four virtual operations are pure.
00413 class CBatchDataGenerator : public CParameterObject
00414 {
00415 protected:
00416 public:
00417 CBatchDataGenerator() {};
00418 virtual ~CBatchDataGenerator() {};
00419
// weighting defaults to 1.0. Overriders must keep this four-parameter
// signature; a version with a different parameter list does not override it.
00420 virtual void addInput(CStateCollection *state, CAction *action, double output, double weighting = 1.0) = 0;
00421
00422 virtual void trainFA() = 0;
00423 virtual void resetPolicyEvaluation() = 0;
00424
00425 virtual double getValue(CStateCollection *state, CAction *action) = 0;
00426
// Non-virtual and defined out of line; presumably walks the episode history
// and feeds it through addInput() -- confirm.
00427 void generateInputData(CEpisodeHistory *logger);
00428 };
00429
// Batch data generator backed by one input data set, one output data set and
// one weighting data set, tied to a CAbstractVFunction. It can be driven by
// either a CSupervisedLearner or a CSupervisedWeightedLearner and implements
// all pure virtuals of CBatchDataGenerator.
00430 class CBatchVDataGenerator : public CBatchDataGenerator
00431 {
00432 protected:
00433 CDataSet *inputData;
00434 CDataSet1D *outputData;
00435 CDataSet1D *weightingData;
00436 ColumnVector *buffVector;
00437
00438 CAbstractVFunction *vFunction;
00439
// Presumably only one of these two pointers is set, depending on which
// public constructor was used -- confirm.
00440 CSupervisedLearner *learner;
00441 CSupervisedWeightedLearner *weightedLearner;
00442
// Protected constructor without a V-function, available to subclasses only.
00443 CBatchVDataGenerator(CSupervisedLearner *learner, int inputDim);
00444
00445
00446 public:
00447 CBatchVDataGenerator(CAbstractVFunction *vFunction, CSupervisedLearner *learner);
00448 CBatchVDataGenerator(CAbstractVFunction *vFunction, CSupervisedWeightedLearner *learner);
00449 virtual ~CBatchVDataGenerator();
00450
00451 virtual void init(int numDim);
00452
00453 virtual void addInput(CStateCollection *state, CAction *action, double output, double weighting = 1.0);
00454
00455 virtual void trainFA();
00456 virtual void resetPolicyEvaluation();
00457
00458 virtual double getValue(CStateCollection *state, CAction *action);
00459
// Accessors returning data-set pointers; presumably the members declared
// above.
00460 virtual CDataSet *getInputData();
00461 virtual CDataSet1D *getOutputData();
00462 virtual CDataSet1D *getWeighting();
00463 };
00464
00465
00466
// CBatchVDataGenerator variant tied to a CContinuousActionQFunction and a
// CStateProperties object; it re-declares addInput() and getValue().
00467 class CBatchCAQDataGenerator : public CBatchVDataGenerator
00468 {
00469 protected:
00470 CContinuousActionQFunction *qFunction;
00471 CStateProperties *properties;
00472
00473
00474 public:
00475 CBatchCAQDataGenerator(CStateProperties *properties, CContinuousActionQFunction *qFunction, CSupervisedLearner *learner);
00476 virtual ~CBatchCAQDataGenerator();
00477
// NOTE(review): this takes three parameters, while the base-class virtual
// addInput() takes four (with `double weighting = 1.0`). It therefore does
// NOT override the base version; it is a new virtual that hides it. Calls
// made through a CBatchDataGenerator* or CBatchVDataGenerator* still reach
// CBatchVDataGenerator::addInput(). If an override was intended, the
// weighting parameter must be added here and in the definition -- confirm.
00478 virtual void addInput(CStateCollection *state, CAction *action, double output);
00479
00480 virtual double getValue(CStateCollection *state, CAction *action);
00481 };
00482
00483
// Batch data generator that keeps separate input, output, buffer and
// weighting containers per action (maps keyed by CAction pointer) for a
// CQFunction. It implements all pure virtuals of CBatchDataGenerator.
00484 class CBatchQDataGenerator : public CBatchDataGenerator
00485 {
00486 protected:
// The maps themselves are held by pointer; ownership of the maps and of
// their values is not visible in this header.
00487 std::map<CAction *, CDataSet *> *inputMap;
00488 std::map<CAction *, CDataSet1D *> *outputMap;
00489 std::map<CAction *, ColumnVector *> *buffVectorMap;
00490 std::map<CAction *, CDataSet1D *> *weightedMap;
00491
00492
00493 CQFunction *qFunction;
00494 CStateProperties *properties;
00495 CActionSet *actions;
00496
00497 CSupervisedQFunctionLearner *learner;
00498 CSupervisedQFunctionWeightedLearner *weightedLearner;
00499 public:
// inputState is optional (default NULL) in the first two constructors. The
// third constructor takes neither a Q-function nor a learner.
00500 CBatchQDataGenerator(CQFunction *qFunction, CSupervisedQFunctionLearner *learner, CStateProperties *inputState = NULL);
00501 CBatchQDataGenerator(CQFunction *qFunction, CSupervisedQFunctionWeightedLearner *learner, CStateProperties *inputState = NULL);
00502 CBatchQDataGenerator(CActionSet *actions, CStateProperties *properties);
00503
// Non-virtual; presumably the shared set-up routine used by the constructors
// -- confirm.
00504 void init(CQFunction *qFunction, CActionSet *actions, CStateProperties *properties);
00505
00506 virtual ~CBatchQDataGenerator();
00507
00508 virtual void addInput(CStateCollection *state, CAction *action, double output, double weighting = 1.0);
00509
00510 virtual void trainFA();
00511 virtual void resetPolicyEvaluation();
00512
00513 virtual double getValue(CStateCollection *state, CAction *action);
00514
// Per-action accessors (non-virtual).
00515 CDataSet *getInputData(CAction *action);
00516 CDataSet1D *getOutputData(CAction *action);
00517
00518 CStateProperties *getStateProperties(CAction *action);
00519 };
00520
00521
00522
// Policy evaluation that works from a CEpisodeHistory and a CRewardHistory
// and delegates sample storage and training to a CBatchDataGenerator.
// Optional collaborators (data collector, actor learner, initial policy
// evaluation) are attached through setters. Both evaluatePolicy overloads
// are declared here, so neither base-class overload is hidden.
00523 class CFittedIteration : public CPolicyEvaluation
00524 {
00525 protected:
// The visible public constructor does not take an estimation policy, so this
// member is presumably set by subclasses -- confirm.
00526 CAgentController *estimationPolicy;
00527 CBatchDataGenerator *dataGenerator;
00528
00529 CEpisodeHistory *episodeHistory;
00530 CRewardHistory *rewardLogger;
00531
00532 CDataCollector *dataCollector;
00533 CPolicyEvaluation *actorLearner;
00534
// Overridable hook; CFittedQIteration overrides it with the same signature.
// nextHistoryActon defaults to NULL and nextReward to 0.0 ("Acton" is the
// spelling used by the declaration).
00535 virtual void addResidualInput(CStep *step, CAction *action, double oldV, double newV, double nearestNeighborDistance, CAction *nextHistoryActon = NULL, double nextReward = 0.0);
00536
00537 CPolicyEvaluation *initialPolicyEvaluation;
00538
00539 virtual double getWeighting(CStateCollection *state, CAction *action);
00540
00541 virtual double getValue(CStateCollection *state, CAction *action);
00542
// Stored as int; presumably refreshed in onParametersChanged() -- confirm.
00543 int useResidualAlgorithm;
00544
// Presumably overrides a CParameterObject notification hook -- confirm in
// cparameters.h.
00545 virtual void onParametersChanged();
00546 public:
00547
00548
00549 CFittedIteration(CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CBatchDataGenerator *dataGenerator);
00550
00551 virtual ~CFittedIteration();
00552
00553
00554
00555 virtual void doEvaluationTrial();
00556 virtual void evaluatePolicy(int trials);
00557
00558 virtual CBatchDataGenerator *createTrainingsData();
00559
00560 virtual void setDataCollector(CDataCollector *dataCollector);
00561
00562 virtual void setInitialPolicyEvaluation(CPolicyEvaluation *initialPolicyEvaluation);
00563
00564 virtual void resetLearnData();
00565
// The only non-virtual setter in this class.
00566 void setActorLearner(CPolicyEvaluation *actorLearner);
00567 virtual void evaluatePolicy();
00568 };
00569
// CFittedIteration variant for a CContinuousActionQFunction. It adds no
// members and overrides nothing; only the constructor and destructor differ
// from the base class.
// Presumably the constructor builds a CBatchCAQDataGenerator from its
// arguments -- confirm in the implementation.
00570 class CFittedCAQIteration : public CFittedIteration
00571 {
00572 protected:
00573
00574
00575 public:
00576 CFittedCAQIteration(CContinuousActionQFunction *qFunction, CStateProperties *properties, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedLearner *learner);
00577
00578 virtual ~CFittedCAQIteration();
00579 };
00580
00581
// CFittedIteration variant for a CAbstractVFunction. It overrides
// getWeighting() and offers one constructor per learner type; only the
// weighted-learner constructor takes a CStochasticPolicy.
00582 class CFittedVIteration : public CFittedIteration
00583 {
00584 protected:
// NOTE(review): this member has the same name as
// CFittedIteration::estimationPolicy (a CAgentController*) and hides it
// inside this class; the two pointers are independent -- confirm that the
// base member is also initialised.
00585 CStochasticPolicy *estimationPolicy;
00586
00587 virtual double getWeighting(CStateCollection *state, CAction *action);
00588
// Raw buffer; its length is not visible here (presumably one entry per
// action in availableActions -- confirm).
00589 double *actionProbabilities;
00590 CActionSet *availableActions;
00591 public:
00592 CFittedVIteration(CAbstractVFunction *vFunction, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedLearner *learner);
00593
00594
00595 CFittedVIteration(CAbstractVFunction *vFunction, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedWeightedLearner *learner, CStochasticPolicy *estimationPolicy);
00596
00597 virtual ~CFittedVIteration();
00598
00599
00600 };
00601
// CFittedIteration variant for a CQFunction. It keeps per-action containers
// (maps keyed by CAction pointer) for data sets, KD-trees, nearest-neighbour
// searchers and data preprocessors, and overrides addResidualInput() and
// doEvaluationTrial().
00602 class CFittedQIteration : public CFittedIteration
00603 {
00604 protected:
00605 std::map<CAction *, CDataSet *> *inputDatas;
00606 std::map<CAction *, CDataSet1D *> *outputDatas;
00607 std::map<CAction *, CKDTree *> *kdTrees;
00608 std::map<CAction *, CKNearestNeighbors *> *nearestNeighbors;
00609 std::map<CAction *, CDataPreprocessor *> *dataPreProc;
00610
00611 std::list<int> *neighborsList;
00612
00613 CStateProperties *residualProperties;
00614
// Presumably the number of nearest neighbours queried -- confirm.
00615 int kNN;
00616
00617 CState *buffState;
00618
// Same signature and defaults as CFittedIteration::addResidualInput().
00619 virtual void addResidualInput(CStep *step, CAction *action, double oldV, double newV, double nearestNeighborDistance, CAction *nextHistoryActon = NULL, double nextReward = 0.0);
00620
00621 public:
// The two constructors differ only by the extra inputState parameter;
// residualProperties is optional (default NULL) in both.
00622 CFittedQIteration(CQFunction *qFunction, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedQFunctionLearner *learner, CStateProperties *residualProperties = NULL);
00623
00624 CFittedQIteration(CQFunction *qFunction, CStateProperties *inputState, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CSupervisedQFunctionLearner *learner, CStateProperties *residualProperties = NULL);
00625
00626
00627 virtual ~CFittedQIteration();
00628
00629 virtual void doEvaluationTrial();
00630 };
00631
00632
00633
// Combines CFittedQIteration with the CNewFeatureCalculatorDataGenerator
// interface and implements that interface's two pure virtuals
// (calculateNewFeatures, swapValueFunctions) plus resetData(). It holds two
// Q-functions and two per-action maps of feature calculators.
// Presumably swapValueFunctions() exchanges the roles of qFunction and
// qFunctionPolicy -- confirm in the implementation.
00634 class CFittedQNewFeatureCalculator : public CFittedQIteration, public CNewFeatureCalculatorDataGenerator
00635 {
00636 protected:
00637 CQFunction *qFunction;
00638 CQFunction *qFunctionPolicy;
00639
00640 std::map<CAction *, CFeatureCalculator *> *policyCalculator;
00641 std::map<CAction *, CFeatureCalculator *> *estimationCalculator;
00642
00643 CNewFeatureCalculator *newFeatureCalc;
// Typed as CStateModifiersObject* despite the name `agent`; set through
// setStateModifiersObject().
00644 CStateModifiersObject *agent;
00645 public:
// Unlike the CFittedQIteration constructors, this one takes no supervised
// learner; which one the base class receives is not visible in this header.
00646 CFittedQNewFeatureCalculator(CQFunction *qFunction, CQFunction *qFunctionPolicy, CStateProperties *inputState, CAgentController *estimationPolicy, CEpisodeHistory *episodeHistory, CRewardHistory *rewardLogger, CNewFeatureCalculator *newFeatCalc);
00647
00648 virtual ~CFittedQNewFeatureCalculator();
00649
00650 virtual void calculateNewFeatures();
00651
00652 virtual void swapValueFunctions();
00653
// Non-virtual helper; presumably empties the two calculator maps -- confirm.
00654 void clearCalculators();
00655 virtual void resetData();
00656
00657 void setStateModifiersObject(CStateModifiersObject *agent);
00658
00659 };
00660
// Repeats the CActionDistribution forward declaration made in the list of
// forward declarations earlier in this header (redundant but harmless).
00661 class CActionDistribution;
00662
// Abstract policy evaluation built on a CSamplingBasedTransitionModel and a
// CActionSet. Subclasses supply getValue(), updateOutputs(), learn() and
// resetDynamicProgramming(); this class declares evaluatePolicy(int),
// getValueFromDistribution() and resetLearnData().
00663 class CContinuousDynamicProgramming : public CPolicyEvaluation
00664 {
00665 protected:
00666 CSamplingBasedTransitionModel *transModel;
00667
00668 CActionSet *availableActions;
00669
00670
// Raw buffers; their lengths are not visible here (presumably one entry per
// action -- confirm).
00671 double *actionValues;
00672 double *actionProbabilities;
00673
00674 int numIteration;
00675 public:
00676 CContinuousDynamicProgramming(CActionSet *allActions, CSamplingBasedTransitionModel *transModel);
00677
00678 virtual ~CContinuousDynamicProgramming();
00679
00680 virtual void evaluatePolicy(int numEvaluations);
00681
// NOTE(review): the parameters availableActions and actionValues have the
// same names as the members above.
00682 virtual double getValueFromDistribution(CActionSet *availableActions, double *actionValues, CActionDistribution *distribution);
00683
00684 virtual double getValue(CState *state, CActionSet *availableActions) = 0;
00685 virtual void updateOutputs(int index, CActionSet *availableActions, double *actionValues) = 0;
00686
00687 virtual void learn() = 0;
00688
00689 virtual void resetLearnData();
00690 virtual void resetDynamicProgramming() = 0;
00691 };
00692
// Concrete CContinuousDynamicProgramming for a CAbstractVFunction trained by
// a CSupervisedLearner. It implements all four pure virtuals of the base
// class and keeps a single output data set.
00693 class CContinuousDynamicVProgramming : public CContinuousDynamicProgramming
00694 {
00695 protected:
00696 CDataSet1D *outputValues;
00697 CActionDistribution *distribution;
00698
00699 CSupervisedLearner *learner;
00700 CAbstractVFunction *vFunction;
00701 public:
00702 CContinuousDynamicVProgramming(CActionSet *allActions, CActionDistribution *distribution, CSamplingBasedTransitionModel *transModel, CAbstractVFunction *vFunction, CSupervisedLearner *learner);
00703
00704 virtual ~CContinuousDynamicVProgramming();
00705
00706 virtual double getValue(CState *state, CActionSet *availableActions);
00707 virtual void updateOutputs(int index, CActionSet *availableActions, double *actionValues);
00708
00709 virtual void learn();
00710
00711 virtual void resetDynamicProgramming();
00712 };
00713
// Concrete CContinuousDynamicProgramming for a CAbstractQFunction trained by
// a CSupervisedQFunctionLearner. It implements all four pure virtuals of the
// base class and keeps per-action input and output data sets (maps keyed by
// CAction pointer).
00714 class CContinuousDynamicQProgramming : public CContinuousDynamicProgramming
00715 {
00716 protected:
00717 std::map<CAction *, CDataSet1D *> *outputValues;
00718 std::map<CAction *, CDataSet *> *inputValues;
00719
00720 CSupervisedQFunctionLearner *learner;
00721 CAbstractQFunction *qFunction;
00722
00723 CActionDistribution *distribution;
// Second raw buffer in addition to the inherited actionValues; its purpose
// and length are not visible in this header.
00724 double *actionValues2;
00725 public:
// NOTE(review): the fourth parameter is named vFunction although its type is
// CAbstractQFunction*; presumably it initialises the qFunction member.
00726 CContinuousDynamicQProgramming(CActionSet *allActions, CActionDistribution *distribution, CSamplingBasedTransitionModel *transModel, CAbstractQFunction *vFunction, CSupervisedQFunctionLearner *learner);
00727
00728 virtual ~CContinuousDynamicQProgramming();
00729
00730 virtual double getValue(CState *state, CActionSet *availableActions);
00731 virtual void updateOutputs(int index, CActionSet *availableActions, double *actionValues);
00732
00733 virtual void learn();
00734
00735 virtual void resetDynamicProgramming();
00736 };
00737
// CContinuousDynamicQProgramming variant that holds a
// CPolicySameStateEvaluator and overrides only getValue().
// Presumably getValue() obtains the value from the evaluator instead of the
// Q-function -- confirm in the implementation.
00738 class CContinuousMCQEvaluation : public CContinuousDynamicQProgramming
00739 {
00740 protected:
00741 CPolicySameStateEvaluator *evaluator;
00742 public:
// Takes an evaluator where the base-class constructor takes a Q-function;
// what is passed to the base as Q-function is not visible in this header.
00743 CContinuousMCQEvaluation(CActionSet *allActions, CActionDistribution *distribution, CSamplingBasedTransitionModel *transModel, CPolicySameStateEvaluator *evaluator, CSupervisedQFunctionLearner *learner);
00744
00745 virtual ~CContinuousMCQEvaluation();
00746
00747 virtual double getValue(CState *state, CActionSet *availableActions);
00748 };
00749
00750
// Forward declarations for the graph-based evaluators below.
// NOTE(review): `DataSubset` here has no C prefix, while the list of forward
// declarations earlier in this header declares `CDataSubset`; saveCSV() below
// uses `DataSubset` -- confirm which name is the real type.
00751 class CGraphTransition;
00752 class DataSubset;
00753 class CKDRectangle;
00754
// Policy evaluation over a CSamplingBasedGraph with one output data set.
// It exposes value look-ups by node index or by input vector, a
// nearest-node query, a maximum-transition query and a CSV export.
00755 class CGraphDynamicProgramming : public CPolicyEvaluation
00756 {
00757 protected:
00758 CSamplingBasedGraph *transModel;
00759 CDataSet1D *outputValues;
00760
00761
00762 public:
// Public flag; presumably controls whether resetLearnData() also resets the
// graph -- confirm.
00763 bool resetGraph;
00764
00765 CGraphDynamicProgramming(CSamplingBasedGraph *transModel);
00766
00767 virtual ~CGraphDynamicProgramming();
00768
00769 virtual void evaluatePolicy(int numEvaluations);
00770
00771 virtual void resetLearnData();
00772
00773 virtual double getValue(int node);
00774 virtual double getValue(ColumnVector *input);
00775
// maxValue is an output parameter (non-const reference); range is optional
// and defaults to NULL.
00776 virtual CGraphTransition *getMaxTransition(int index, double &maxValue, CKDRectangle *range = NULL);
00777
// node and distance are output parameters (non-const references).
00778 virtual void getNearestNode(ColumnVector *input, int &node, double &distance);
00779
00780 CSamplingBasedGraph *getGraph();
00781
00782 CDataSet1D *getOutputValues();
00783
// NOTE(review): `string` is unqualified, so this only compiles if the name is
// visible without std:: (e.g. via a using-directive in an included header) --
// confirm. filename is taken by value.
00784 virtual void saveCSV(string filename, DataSubset *nodeSubset);
00785 };
00786
// Forward declarations for the adaptive-target evaluator below.
00787 class CGraphTarget;
00788 class CAdaptiveTargetGraph;
00789
// CGraphDynamicProgramming variant that works with several CGraphTarget
// objects: it keeps a list of targets, a map from each target to a
// CDataSet1D, and a pointer to the currently selected target. It overrides
// getMaxTransition() and resetLearnData().
// Presumably the map holds one set of output values per target -- confirm in
// the implementation.
00790 class CGraphAdaptiveTargetDynamicProgramming : public CGraphDynamicProgramming
00791 {
00792 protected:
00793 std::map<CGraphTarget *, CDataSet1D *> *targetMap;
00794 std::list<CGraphTarget *> *targets;
00795
00796 CGraphTarget *currentTarget;
00797 CAdaptiveTargetGraph *adaptiveTargetGraph;
00798 public:
00799 CGraphAdaptiveTargetDynamicProgramming(CAdaptiveTargetGraph *graph);
00800
// Spelled `virtual` for consistency with the rest of the hierarchy. The base
// class destructor is already virtual, so this does not change behaviour.
00801 virtual ~CGraphAdaptiveTargetDynamicProgramming();
00802
// Same signature and default (range = NULL) as the base-class version.
00803 virtual CGraphTransition *getMaxTransition(int index, double &maxValue, CKDRectangle *range = NULL);
00804
00805 virtual void addTarget(CGraphTarget *target);
00806 virtual CGraphTarget *getTargetForState(CStateCollection *state);
00807
00808 virtual void setCurrentTarget(CGraphTarget *target);
00809
// Non-virtual accessors; presumably they index into the targets list.
00810 int getNumTargets();
00811 CGraphTarget *getTarget(int index);
00812
00813 virtual void resetLearnData();
00814 };
00815
00816
00817
00818 #endif
00819