Browse Source

add callback to optimizer

Willi Zschiebsch 4 years ago
parent
commit
fbeae60d9d

+ 5 - 0
lib/include/Connector.h

@@ -31,6 +31,8 @@ namespace mdd {
 			{"state", std::bind(&Connector::state,this,std::placeholders::_1)}
         };
 
+		std::function <json(const json&)> _callback;
+
 		/**
 		*subject
 		*object
@@ -59,9 +61,12 @@ namespace mdd {
 
 		json load(const json& args);
 
+		json optimizerCallback(const json& callback);
+
 	public:
 		Connector();
 		json decode(const json& request);
 		json encode();
+		void attachCallback(std::function<json(const json&)> callback);
 	};
 }

+ 1 - 0
lib/include/IOptimizer.h

@@ -15,6 +15,7 @@ namespace mdd {
 			
 			virtual bool setEvaluation(std::string func) = 0;
 			virtual double update() = 0;
+			virtual void attachCallback(std::function<json(const json&)> callback) = 0;
 	};
 }
 #endif // !IOPTIMIZER_H

+ 4 - 0
lib/include/OptimizerBase.h

@@ -21,6 +21,8 @@ namespace mdd {
 		std::vector<double> _output_vals;
 		exprtk::expression<double> _func_expr;
 
+		std::function<json(const json&)> _callback;
+
 		struct opt_state {
 			state module_state = state::STATE_ERROR;
 			double opt_value = 0;
@@ -41,6 +43,8 @@ namespace mdd {
 		void load(const json& j);
 		json dump();
 		json getIdentifier() override;
+
+		void attachCallback(std::function<json(const json&)> callback) override;
 		//state update() override;
 	};
 }

+ 2 - 0
lib/include/OptimizerEvolutionary.h

@@ -33,6 +33,8 @@ namespace mdd {
 		size_t _converges;
 		double _precision;
 
+		int generation = -1;
+
 		void evolve(std::vector<Individual> parents);
 		void evaluate(size_t ignore_len = 0);
 		std::vector<double> mutateGene(std::shared_ptr<IInput> input, std::vector<double> seed = std::vector<double>());

+ 9 - 1
lib/src/Connector.cpp

@@ -398,13 +398,17 @@ namespace mdd {
 			_opt = regi.generateOptimizer("OptimizerEvolutionary");
 			_opt->connect(_root);
 		}
-		
+		_opt->attachCallback(std::bind(&Connector::optimizerCallback,this,std::placeholders::_1));
 		json ret;
 		ret["operation"] = "state";
 		ret["args"].push_back(encode());
 		return ret;
 	}
 
+	json Connector::optimizerCallback(const json& callback) {
+		return _callback(callback);
+	}
+
 	Connector::Connector()
 	{
 		
@@ -442,5 +446,9 @@ namespace mdd {
 		return ret;
 	}
 
+	void Connector::attachCallback(std::function<json(const json&)> callback)
+	{
+		_callback = callback;
+	}
 }
 

+ 5 - 1
lib/src/OptimizerBase.cpp

@@ -21,7 +21,7 @@ namespace mdd{
 			ret.opt_value = _func_expr.value();
 			//std::cout << "get: " << "opt" << ": " << ret.opt_value << std::endl;
 		}
-		
+
 		return ret;
 	}
 
@@ -154,4 +154,8 @@ namespace mdd{
 		//jID["prefix"] = std::vector<std::string>();
 		return jID;
 	}
+
+	void OptimizerBase::attachCallback(std::function<json(const json&)> callback) {
+		_callback = callback;
+	}
 }

+ 68 - 5
lib/src/OptimizerEvolutionary.cpp

@@ -12,6 +12,26 @@ namespace mdd {
 				_inputs[j]->setValue() = it->dna[j];
 			}
 			auto opt = updateOutputs();
+			json ret;
+			ret["generation"] = generation;
+			json jind;
+			jind["fitness"] = it->fitness;
+			jind["dna"] = it->dna;
+			ret["individual"] = jind;
+			if (opt.module_state == state::STATE_ERROR)
+			{
+				ret["state"] = "error";
+			}
+			else {
+				ret["state"] = "ok";
+			}
+			ret["processor"] = _module->dump();
+			json jcall;
+			jcall["operation"] = "change";
+			jcall["args"]["subject"] = getIdentifier();
+			jcall["args"]["object"] = ret;
+			_callback(jcall);
+
 			if (opt.module_state == state::STATE_ERROR) {
 				_children.erase(it);
 			}
@@ -45,6 +65,21 @@ namespace mdd {
 				}
 			}
 		}
+		
+		json ret;
+		ret["generation"] = generation;
+		for (auto& ind : _bests)
+		{
+			json jind;
+			jind["fitness"] = ind.fitness;
+			jind["dna"] = ind.dna;
+			ret["bests"].push_back(jind);
+		}
+		json jcall;
+		jcall["operation"] = "add";
+		jcall["args"]["subject"] = getIdentifier();
+		jcall["args"]["object"] = ret;
+		_callback(jcall);
 	}
 
 	OptimizerEvolutionary::OptimizerEvolutionary()
@@ -147,7 +182,7 @@ namespace mdd {
 
 	double OptimizerEvolutionary::update()
 	{
-		int gen = -1;
+		generation = -1;
 		bool check;
 		Individual old_best;
 		size_t same_counter = 0;
@@ -158,17 +193,31 @@ namespace mdd {
 		}
 		do
 		{
-			std::cout << _children.size() << " | " << gen << std::endl;
+			std::cout << _children.size() << " | " << generation << std::endl;
 			if (_children.empty())
 			{
 				for (size_t i = 0; i < _grow_generation*2; i++)
 				{
 					_children.push_back(generateIndividual());
 				}
+				json ret;
+				ret["generation"] = generation;
+				for (auto& ind : _children)
+				{
+					json jind;
+					jind["fitness"] = ind.fitness;
+					jind["dna"] = ind.dna;
+					ret["individuals"] = jind;
+				}
+				json jcall;
+				jcall["operation"] = "add";
+				jcall["args"]["subject"] = getIdentifier();
+				jcall["args"]["object"] = ret;
+				_callback(jcall);
 				evaluate();
 			}
 			else {
-				if (gen != -1)
+				if (generation != -1)
 				{
 					evolve(_children);
 				}
@@ -177,8 +226,8 @@ namespace mdd {
 				}
 			}
 			
- 			++gen;
-			check = gen < _min_generations || _bests[0].fitness > _max_fitness;
+ 			++generation;
+			check = generation < _min_generations || _bests[0].fitness > _max_fitness;
 			if (!check && _converges > 0)
 			{
 				bool found = false;
@@ -338,6 +387,20 @@ namespace mdd {
 			int p2 = (int)random_num(0, gen_pool.size() - 1);
 			_children.push_back(combine(gen_pool[p1], gen_pool[p2]));
 		}
+		json ret;
+		ret["generation"] = generation;
+		for (auto& ind : _children)
+		{
+			json jind;
+			jind["fitness"] = ind.fitness;
+			jind["dna"] = ind.dna;
+			ret["individuals"] = jind;
+		}
+		json jcall;
+		jcall["operation"] = "add";
+		jcall["args"]["subject"] = getIdentifier();
+		jcall["args"]["object"] = ret;
+		_callback(jcall);
 		evaluate(init_len);
 	}
 

+ 18 - 10
server/src/main.cpp

@@ -2,7 +2,6 @@
 #if (defined (WIN32))
     #include <Windows.h>
 #endif
-#include <Registration.h>
 #include <json.hpp>
 #include <zmq.hpp>
 #include "zhelpers.hpp"
@@ -14,7 +13,6 @@ using namespace mdd;
 class Server {
 private:
     zmq::context_t context;
-    Registration regi = Registration();
     Connector _connector = Connector();
 
     json msg_header() {
@@ -75,6 +73,22 @@ private:
             jmsg["change"] = _connector.decode(request);
         }
 
+        sendChange(jmsg);
+        
+        //auto end = std::chrono::steady_clock::now();
+        //std::chrono::duration<double> elapsed_seconds = middle - start;
+        //std::chrono::duration<double> elapsed_seconds2 = end - middle;
+        //std::cout << "[Server]: 1:" << elapsed_seconds.count() << " s 2:" << elapsed_seconds2.count() << "s\n";
+    }
+
+    json sendChange(const json& jchange) {
+        json jmsg = jchange;
+        if (!jmsg.contains("change"))
+        {
+            jmsg = msg_header();
+            jmsg["change"] = jchange;
+        }
+        
         std::string smsg = jmsg.dump();
         std::string stopic = "CHANGE";
         zmq::message_t topic(stopic.size());
@@ -92,15 +106,8 @@ private:
         catch (zmq::error_t& e) {
             std::cout << e.what() << std::endl;
         }
-        //msg.rebuild(3);
-        //topic.rebuild(4);
-
         ++counter;
-        
-        //auto end = std::chrono::steady_clock::now();
-        //std::chrono::duration<double> elapsed_seconds = middle - start;
-        //std::chrono::duration<double> elapsed_seconds2 = end - middle;
-        //std::cout << "[Server]: 1:" << elapsed_seconds.count() << " s 2:" << elapsed_seconds2.count() << "s\n";
+        return jmsg;
     }
 
 public:
@@ -118,6 +125,7 @@ public:
         id = reinterpret_cast<uint32_t>(this);
         reply_socket.bind("tcp://*:5555");
         publisher_socket.bind("tcp://*:5556");
+        _connector.attachCallback(std::bind(&Server::sendChange,this,std::placeholders::_1));
     }
 
     void handle_request() {