00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #ifndef CGRIDWORDMODEL_H
00033 #define CGRIDWORDMODEL_H
00034
00035 #ifdef WIN32
00036 #include <windows.h>
00037 #endif // WIN32
00038
00039 #include "ctransitionfunction.h"
00040 #include "crewardfunction.h"
00041 #include "cagentcontroller.h"
00042 #include "cdiscretizer.h"
00043 #include "caction.h"
00044
00045 #include <math.h>
00046 #include <vector>
00047 #include <map>
00048 #include <set>
00049
// Numeric tag associated with grid-world actions.
// NOTE(review): the meaning of the value 16 is not visible in this header
// (presumably an action-type identifier used to recognise CGridWorldAction
// instances) — confirm against the implementation before relying on it.
#define GRIDWORLDACTION 16
00051
/**
 * @brief Rectangular grid of character cells, plus the sets of cell symbols
 *        that mark start, target and prohibited cells.
 *
 * A grid is either created with explicit dimensions or loaded from a file /
 * FILE stream, and can be written back with save(). Several mutators are
 * virtual so derived classes can react to grid changes.
 *
 * Ownership: the symbol sets and the grid rows are held through raw pointers.
 * The class declares a destructor together with allocGrid()/deallocGrid(), so
 * it presumably allocates and releases that storage itself. The implicitly
 * generated copy operations would copy only the pointers, leaving two objects
 * that release the same memory. Copying is therefore disabled (see the private
 * section at the end).
 */
class CGridWorld
{
protected:
	/// Grid dimensions along x and y.
	int size_x, size_y;
	/// Cell symbols treated as start cells (managed by add/removeStartValue()).
	std::set<char>* start_values;
	/// Cell symbols treated as target cells (managed by add/removeTargetValue()).
	std::set<char>* target_values;
	/// Cell symbols treated as prohibited cells (managed by add/removeProhibitedValue()).
	std::set<char>* prohibited_values;

	/// Grid cells; presumably one char array per row — TODO confirm against allocGrid().
	std::vector<char *>* grid;

	/// Allocate / release the storage behind @c grid.
	void allocGrid();
	void deallocGrid();

	/// Whether the grid storage is currently allocated (presumably kept in sync by allocGrid()/deallocGrid()).
	bool is_allocated;

public:
	/// Creates a grid world from the file @p filename (presumably by calling load()).
	CGridWorld(char* filename);
	/// Creates a grid world with the given dimensions.
	CGridWorld(unsigned int size_x, unsigned int size_y);
	virtual ~CGridWorld();

	/// Loads the grid world from the file @p filename.
	void load(char* filename);
	/// Loads the grid world from an already opened stream; overridable.
	virtual void load(FILE *stream);

	/// Saves the grid world to the file @p filename.
	void save(char* filename);
	/// Saves the grid world to an already opened stream; overridable.
	virtual void save(FILE *stream);

	virtual void initGrid();
	virtual bool isValid();

	/// Writes / reads the symbol stored in cell (@p pos_x, @p pos_y).
	virtual void setGridValue(unsigned int pos_x, unsigned int pos_y, char value);
	char getGridValue(unsigned int pos_x, unsigned int pos_y);

	/// Maintain the set of start-cell symbols.
	virtual void addStartValue(char value);
	virtual void removeStartValue(char value);
	std::set<char> *getStartValues();

	/// Maintain the set of target-cell symbols.
	virtual void addTargetValue(char value);
	virtual void removeTargetValue(char value);
	std::set<char> *getTargetValues();

	/// Maintain the set of prohibited-cell symbols.
	virtual void addProhibitedValue(char value);
	virtual void removeProhibitedValue(char value);
	std::set<char> *getProhibitedValues();

	/// Changes the grid dimensions.
	void setSize(unsigned int size_x, unsigned int size_y);

	unsigned int getSizeX();
	unsigned int getSizeY();

	/// Presumably returns the symbols that occur in the grid. Whether the caller
	/// must delete the returned set is not visible here — TODO confirm.
	std::set<char> *getUsedValues();

private:
	// Declared but intentionally not defined (pre-C++11 "noncopyable" idiom):
	// a member-wise copy would share the owned containers and release them twice.
	CGridWorld(const CGridWorld &);
	CGridWorld &operator=(const CGridWorld &);
};
00103
00104
00105 class CGridWorldModel : public CGridWorld, public CTransitionFunction, public CRewardFunction
00106 {
00107 protected:
00108 unsigned int max_bounces;
00109
00110
00111 std::vector<std::pair<int, int>* >* start_points;
00112 std::map<char, double> *rewards;
00113
00114 double reward_standard;
00115 double reward_success;
00116 double reward_bounce;
00117
00118 bool is_parsed;
00119 virtual void parseGrid();
00120
00121 public:
00122 CGridWorldModel(unsigned int size_x, unsigned int size_y, unsigned int max_bounces);
00123 CGridWorldModel(char* filename, unsigned int max_bounces);
00124 virtual ~CGridWorldModel();
00125
00126 void setMaxBounces(unsigned int value);
00127 unsigned int getMaxBounces();
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138 void setRewardStandard(double value);
00139 void setRewardSuccess(double value);
00140 void setRewardBounce(double value);
00141
00142 void setRewardForSymbol(char symbol, double reward);
00143 double getRewardForSymbol(char symbol);
00144
00145 double getRewardStandard();
00146 double getRewardSuccess();
00147 double getRewardBounce();
00148
00149 virtual void load(FILE *stream);
00150 virtual void initGrid();
00151 virtual void setGridValue(unsigned int pos_x, unsigned int pos_y, char value);
00152 virtual void addStartValue(char value);
00153 virtual void removeStartValue(char value);
00154
00155 virtual void transitionFunction(CState *oldstate, CAction *action, CState *newState, CActionData *data = NULL);
00156
00157 virtual bool isResetState(CState *state);
00158 virtual bool isFailedState(CState *state);
00159
00160 virtual void getResetState(CState *resetState);
00161
00162 virtual double getReward(CStateCollection *oldState, CAction *action, CStateCollection *newState);
00163
00164 };
00165
00166
00167 class CLocal4GridWorldState : public CStateModifier
00168 {
00169 protected:
00170 CGridWorld* grid_world;
00171 public:
00172 CLocal4GridWorldState(CGridWorld *grid_world);
00173 virtual ~CLocal4GridWorldState();
00174
00175 virtual void getModifiedState(CStateCollection *originalState, CState *modifiedState);
00176 };
00177
00178
00179 class CLocal4XGridWorldState : public CStateModifier
00180 {
00181 protected:
00182 CGridWorld* grid_world;
00183 public:
00184 CLocal4XGridWorldState(CGridWorld *grid_world);
00185 virtual ~CLocal4XGridWorldState();
00186
00187 virtual void getModifiedState(CStateCollection *originalState, CState *modifiedState);
00188 };
00189
00190
00191 class CLocal8GridWorldState : public CStateModifier
00192 {
00193 protected:
00194 CGridWorld* grid_world;
00195 public:
00196 CLocal8GridWorldState(CGridWorld *grid_world);
00197 virtual ~CLocal8GridWorldState();
00198
00199 virtual void getModifiedState(CStateCollection *originalState, CState *modifiedState);
00200 };
00201
00202
00203 class CGlobalGridWorldDiscreteState : public CAbstractStateDiscretizer
00204 {
00205 protected:
00206 unsigned int size_x, size_y;
00207
00208 public:
00209 CGlobalGridWorldDiscreteState(unsigned int size_x, unsigned int size_y);
00210 virtual ~CGlobalGridWorldDiscreteState() {};
00211
00212 virtual unsigned int getDiscreteStateNumber(CStateCollection *state);
00213 };
00214
00215
00216 class CLocalGridWorldDiscreteState : public CAbstractStateDiscretizer
00217 {
00218 protected:
00219 CStateModifier* orig_state;
00220 std::map<char, short>* valuemap;
00221
00222 public:
00223 CLocalGridWorldDiscreteState(CStateModifier* orig_state, unsigned int neigbourhood, std::set<char> *possible_values);
00224 virtual ~CLocalGridWorldDiscreteState();
00225
00226 virtual unsigned int getDiscreteStateNumber(CStateCollection *state);
00227 };
00228
00229
00230 class CSmallLocalGridWorldDiscreteState : public CAbstractStateDiscretizer
00231 {
00232 protected:
00233 CStateModifier* orig_state;
00234 CGridWorld *gridworld;
00235
00236 public:
00237 CSmallLocalGridWorldDiscreteState(CStateModifier* orig_state, unsigned int neigbourhood, CGridWorld *gridworld);
00238 virtual ~CSmallLocalGridWorldDiscreteState();
00239
00240 virtual unsigned int getDiscreteStateNumber(CStateCollection *state);
00241 };
00242
00243
00244 class CGridWorldAction : public CPrimitiveAction
00245 {
00246 protected:
00247 int x_move, y_move;
00248
00249 public:
00250 CGridWorldAction(int x_move, int y_move);
00251
00252 int getXMove();
00253 int getYMove();
00254 };
00255
00256
00257 class CGridWorldController : public CAgentStatisticController, public CSemiMDPListener
00258 {
00259 struct GridControllerRecord
00260 {
00261 CGridWorldAction* action;
00262 int pos_x;
00263 int pos_y;
00264 double factor;
00265 double distance;
00266 };
00267
00268 protected:
00269 CGridWorld *gridworld;
00270 std::vector<GridControllerRecord> *record;
00271 std::set<std::pair<unsigned int, unsigned int>*>* target_points;
00272 int lastXMove, lastYMove;
00273
00274 public:
00275 CGridWorldController(CGridWorld *gridworld, CActionSet *actions);
00276 virtual ~CGridWorldController();
00277
00278 void init();
00279
00280 virtual CAction* getNextAction(CStateCollection *state, CActionStatistics *stat);
00281
00282 virtual void nextStep(CStateCollection *, CAction *, CStateCollection *) {};
00283
00284 virtual void newEpisode();
00285 };
00286
00287
00288 #ifdef WIN32
00289
00290 class CGridWorldVisualizer : public CSemiMDPListener
00291 {
00292 protected:
00293 CGridWorldModel *gridworld;
00294 bool flgDisplay;
00295 bool flgTranspose;
00296 HANDLE console;
00297 short xpos, ypos, xoffset, yoffset;
00298
00299 public:
00300 CGridWorldVisualizer(CGridWorldModel *gridworld);
00301 virtual ~CGridWorldVisualizer();
00302
00303 virtual void nextStep(CStateCollection *oldState, CAction *action, CStateCollection *nextState);
00304
00305 virtual void intermediateStep(CStateCollection *oldState, CAction *action, CStateCollection *nextState) {};
00306
00307 virtual void newEpisode();
00308
00309 bool getDisplay();
00310
00311 void setDisplay(bool flgDisplay);
00312 };
00313
00314 #endif // WIN32
00315
00316
00317 class CRaceTrackDiscreteState : public CAbstractStateDiscretizer
00318 {
00319 protected:
00320 CStateModifier* orig_state;
00321 CGridWorld *gridworld;
00322
00323 public:
00324 CRaceTrackDiscreteState(CStateModifier* orig_state, unsigned int neigbourhood, CGridWorld *gridworld);
00325 virtual ~CRaceTrackDiscreteState();
00326
00327 virtual unsigned int getDiscreteStateNumber(CStateCollection *state);
00328 };
00329
00330
00331 class CRaceTrack
00332 {
00333 public:
00334 static void generateRaceTrack(CGridWorld *gridworld, unsigned int width = 40, unsigned int length = 200, unsigned int h_max = 5, unsigned int dy_min = 1, unsigned int dy_max = 8);
00335 };
00336
00337
00338 #endif // CGRIDWORDMODEL_H
00339