learn-piece.cc
Go to the documentation of this file.
00001 /* learn-piece.cc
00002  */
00003 #include "osl/state/numEffectState.h"
00004 #include "osl/move_generator/legalMoves.h"
00005 #include "osl/container/moveVector.h"
00006 #include "osl/record/csaRecord.h"
00007 #include "osl/record/ki2.h"
00008 #include "osl/record/kakinoki.h"
00009 #include "osl/record/kisen.h"
00010 #include "osl/eval/see.h"
00011 #include "osl/pieceStand.h"
00012 #include <boost/algorithm/string/predicate.hpp>
00013 #include <iostream>
00014 using namespace osl;
00015 using namespace std;
00016 namespace csa=osl::record::csa;
00017 CArray<int,PTYPE_SIZE> weight, gradient;
00018 void show() {
00019   for (size_t i=0; i<PieceStand::order.size(); ++i) {
00020     Ptype ptype = PieceStand::order[i];
00021     cout << csa::show(ptype) << ' ' << weight[ptype] << ' ';
00022     if (canPromote(ptype))
00023       cout << csa::show(promote(ptype)) << ' ' << weight[promote(ptype)] << ' ';
00024   }
00025   cout << endl;
00026 #if 0
00027   for (size_t i=0; i<PieceStand::order.size(); ++i) {
00028     Ptype ptype = PieceStand::order[i];
00029     cout << csa::show(ptype) << ' ' << gradient[ptype] << ' ';
00030     if (canPromote(ptype))
00031       cout << csa::show(promote(ptype)) << ' ' << gradient[promote(ptype)] << ' ';
00032   }
00033   cout << endl;
00034 #endif
00035 }
00036 int median() {
00037   osl::vector<int> copy;
00038   for (int i=0; i<PTYPE_SIZE; ++i)
00039     if (gradient[i]!=0) copy.push_back(gradient[i]);
00040   sort(copy.begin(), copy.end());
00041   if (copy.size() == 1) return 0;
00042   if (copy.size()%2) return copy[copy.size()/2];
00043   return copy[copy.size()/2]-1;
00044 }
00045 void update() {
00046   std::vector<std::pair<int,Ptype> > gradient_ptype;
00047   for (size_t i=0; i<PieceStand::order.size(); ++i) {
00048     Ptype ptype = PieceStand::order[i];
00049     gradient_ptype.push_back(std::make_pair(gradient[ptype], ptype));
00050     if (canPromote(ptype)) {
00051       ptype = promote(ptype);
00052       gradient_ptype.push_back(std::make_pair(gradient[ptype], ptype));
00053     }
00054   }
00055   std::sort(gradient_ptype.begin(), gradient_ptype.end());
00056   // bonanza's robust update seems better than standard gradient descent methods, here
00057   // const int a[13] = { -1, -1, -1, -1, -1, -1, 0, 1, 1, 1, 1, 1, 1 }; 
00058   const int a[13] = { -3, -2, -2, -1, -1, -1, 0, 1, 1, 1, 2, 2, 3 };
00059   for (size_t i=0; i<gradient_ptype.size(); ++i)
00060     weight[gradient_ptype[i].second] += a[i];
00061 }
00062 void count(const NumEffectState& state, CArray<int,PTYPE_SIZE>& out) {
00063   out.fill(0);
00064   for (int i=0; i<Piece::SIZE; ++i) {
00065     Piece p = state.pieceOf(i);
00066     out[p.ptype()] += playerToSign(p.owner());
00067   }
00068 }
00069 void compare(Player turn, const NumEffectState& selected, 
00070              const NumEffectState& not_selected) {
00071   CArray<int,PTYPE_SIZE> c0, c1, diff;
00072   count(selected, c0);
00073   count(not_selected, c1);
00074   int evaldiff = 0;
00075   for (int i=0; i<PTYPE_SIZE; ++i) {
00076     diff[i] = (c0[i] - c1[i])*playerToSign(turn);
00077     evaldiff += diff[i] * weight[i];
00078   }
00079   if (evaldiff > 0) return;
00080   for (int i=0; i<PTYPE_SIZE; ++i) 
00081     gradient[i] += diff[i];
00082 }
00083 Move greedymove(const NumEffectState& state) {
00084   MoveVector all;
00085   LegalMoves::generate(state, all);
00086   int best_see = 0;
00087   Move best_move;
00088   for (size_t i=0; i<all.size(); ++i) {
00089     if (! all[i].isCaptureOrPromotion()) continue;
00090     int see = See::see(state, all[i]);
00091     if (see <= best_see) continue;
00092     best_see = see;
00093     best_move = all[i];
00094   }
00095   return best_move;
00096 }
00097 void make_PV(const NumEffectState& src, Move prev, MoveVector& pv) {
00098   NumEffectState state(src);
00099   pv.clear(); 
00100   // todo: quiescence search
00101   while (true) {
00102     state.makeMove(prev);
00103     pv.push_back(prev);
00104     Move move = greedymove(state);
00105     if (! move.isNormal())
00106       return;
00107     prev = move;
00108   }
00109 }
00110 void make_moves(NumEffectState& state, const MoveVector& pv) {
00111   for (size_t i=0; i<pv.size(); ++i)
00112     state.makeMove(pv[i]);
00113 }
00114 
00115 void run(const osl::vector<Move>& moves) {
00116   NumEffectState state;
00117   for (size_t i=0; i<moves.size(); ++i) {
00118     const Move selected = moves[i];
00119     MoveVector all;
00120     LegalMoves::generate(state, all);
00121 
00122     if (! state.hasEffectAt(alt(selected.player()), selected.to())) {
00123       MoveVector pv0;
00124       make_PV(state, selected, pv0);
00125       NumEffectState s0(state);
00126       make_moves(s0, pv0);
00127       for (size_t j=0; j<all.size(); ++j)
00128         if (all[j] != selected) {
00129           MoveVector pv1;
00130           make_PV(state, all[j], pv1);
00131           NumEffectState s1(state);
00132           make_moves(s1, pv1);
00133           compare(state.turn(), s0, s1);
00134         }
00135     }
00136     state.makeMove(selected);
00137   }
00138 }
00139 int main(int argc, char **argv) {
00140   weight.fill(500);
00141   for (int t=0; t<1024; ++t) {
00142     show();
00143     gradient.fill(0);
00144     for (int i=1; i<argc; ++i) {
00145       const char *filename = argv[i];
00146       if (boost::algorithm::iends_with(filename, ".csa")) {
00147         const CsaFile csa(filename);
00148         run(csa.getRecord().getMoves());
00149       }
00150       else if (boost::algorithm::iends_with(filename, ".ki2")) {
00151         const Ki2File ki2(filename);
00152         run(ki2.getRecord().getMoves());
00153       }
00154       else if (boost::algorithm::iends_with(filename, ".kif")
00155                && KakinokiFile::isKakinokiFile(filename)) {
00156         const KakinokiFile kif(filename);
00157         run(kif.getRecord().getMoves());
00158       }
00159       else if (boost::algorithm::iends_with(filename, ".kif")) {
00160         KisenFile kisen(filename);
00161         for (size_t j=0; j<kisen.size(); ++j)
00162           run(kisen.getMoves(j));
00163       }
00164       else {
00165         cerr << "Unknown file type: " << filename << "\n";
00166         continue;
00167       }
00168     }
00169     update();
00170   }
00171 }
00172 // ;;; Local Variables:
00173 // ;;; mode:c++
00174 // ;;; c-basic-offset:2
00175 // ;;; coding:utf-8
00176 // ;;; End:
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines