Commit 080c4249 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Adapt new wrapper interface

parent e680d6bb
......@@ -19,7 +19,7 @@
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.2-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
......@@ -164,7 +165,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
setRootProjectModelsDir(modelsDirPath.toString());
}
rewardFunctionSourceGenerator.generate(getRootProjectModelsDir().get(),
EMAComponentInstanceSymbol emaSymbol = rewardFunctionSourceGenerator.generate(getRootProjectModelsDir().get(),
rewardFunctionRootModel, rewardFunctionOutputPath);
fixArmadilloEmamGenerationOfFile(Paths.get(rewardFunctionOutputPath, String.join("_", fullNameOfComponent) + ".h"));
......@@ -175,12 +176,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (pythonWrapperApi.checkIfPythonModuleBuildAvailable()) {
final String rewardModuleOutput
= Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString();
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(getRootProjectModelsDir().get(),
rewardFunctionRootModel, pythonWrapperOutputPath, rewardModuleOutput);
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(emaSymbol,
pythonWrapperOutputPath, rewardModuleOutput);
} else {
Log.warn("Cannot build wrapper automatically: OS not supported. Please build manually before starting training.");
componentPortInformation = pythonWrapperApi.generate(getRootProjectModelsDir().get(), rewardFunctionRootModel,
pythonWrapperOutputPath);
componentPortInformation = pythonWrapperApi.generate(emaSymbol, pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
......
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
/**
*
*/
public interface RewardFunctionSourceGenerator {
void generate(String modelPath, String qualifiedName, String targetPath);
EMAComponentInstanceSymbol generate(String modelPath, String qualifiedName, String targetPath);
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.util.TrainedArchitectureMockFactory;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.stream.Collectors;
import static junit.framework.TestCase.assertTrue;
import static org.mockito.Mockito.mock;
public class IntegrationPythonWrapperTest extends AbstractSymtabTest{
private RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
@Before
public void setUp() {
// ensure an empty log
Log.getFindings().clear();
Log.enableFailQuick(false);
rewardFunctionSourceGenerator = mock(RewardFunctionSourceGenerator.class);
}
@Test
public void testReinforcementConfigWithRewardGeneration() {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests");
CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
TrainedArchitecture trainedArchitecture = TrainedArchitectureMockFactory.createTrainedArchitectureMock();
trainGenerator.generate(modelPath, "ReinforcementConfig1", trainedArchitecture);
assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code/ReinforcementConfig1"),
Arrays.asList(
"CNNTrainer_reinforcementConfig1.py",
"start_training.sh",
"reinforcement_learning/__init__.py",
"reinforcement_learning/strategy.py",
"reinforcement_learning/agent.py",
"reinforcement_learning/environment.py",
"reinforcement_learning/replay_memory.py",
"reinforcement_learning/util.py",
"reinforcement_learning/cnnarch_logger.py")
);
assertTrue(Paths.get("./target/generated-sources-cnnarch/reward/pylib").toFile().isDirectory());
}
}
......@@ -49,6 +49,7 @@ if __name__ == "__main__":
'state_topic': '/environment/state',
'action_topic': '/environment/action',
'reset_topic': '/environment/reset',
'reward_topic': '/environment/reward',
}
env = reinforcement_learning.environment.RosEnvironment(**env_params)
......
......@@ -2,29 +2,12 @@ import abc
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import reward_rewardFunction_executor
class RewardFunction(object):
def __init__(self):
self.__reward_wrapper = reward_rewardFunction_executor.reward_rewardFunction_executor()
self.__reward_wrapper.init()
def reward(self, state, terminal):
s = state.astype('double')
t = bool(terminal)
inp = reward_rewardFunction_executor.reward_rewardFunction_input()
inp.state = s
inp.isTerminal = t
output = self.__reward_wrapper.execute(inp)
return output.reward
class Environment:
__metaclass__ = abc.ABCMeta
def __init__(self):
self._reward_function = RewardFunction()
pass
@abc.abstractmethod
def reset(self):
......@@ -60,6 +43,8 @@ class RosEnvironment(Environment):
self.__waiting_for_terminal_update = False
self.__last_received_state = 0
self.__last_received_terminal = True
self.__last_received_reward = 0.0
self.__waiting_for_reward_update = False
rospy.loginfo("Initialize node {0}".format(ros_node_name))
......@@ -77,6 +62,9 @@ class RosEnvironment(Environment):
self.__terminal_state_subscriber = rospy.Subscriber(terminal_state_topic, Bool, self.__terminal_state_callback)
rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))
self.__reward_subscriber = rospy.Subscriber(reward_topic, Float32, self.__reward_callback)
rospy.loginfo('Reward Subscriber registered with topic {}'.format(reward_topic))
rate = rospy.Rate(10)
thread.start_new_thread(rospy.spin, ())
......@@ -110,11 +98,12 @@ class RosEnvironment(Environment):
self.__waiting_for_state_update = True
self.__waiting_for_terminal_update = True
self.__waiting_for_reward_update = True
self.__step_publisher.publish(action_rospy)
self.__wait_for_new_state(self.__step_publisher, action_rospy)
next_state = self.__last_received_state
terminal = self.__last_received_terminal
reward = self.__calc_reward(next_state, terminal)
reward = self.__last_received_reward
rospy.logdebug('Calculated reward: {}'.format(reward))
return next_state, reward, terminal, 0
......@@ -123,7 +112,7 @@ class RosEnvironment(Environment):
time_of_timeout = time.time() + self.__timeout_in_s
timeout_counter = 0
while(self.__waiting_for_state_update
or self.__waiting_for_terminal_update):
or self.__waiting_for_terminal_update or self.__waiting_for_reward_update):
is_timeout = (time.time() > time_of_timeout)
if (is_timeout):
if timeout_counter < 3:
......@@ -150,6 +139,7 @@ class RosEnvironment(Environment):
logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
self.__waiting_for_terminal_update = False
def __calc_reward(self, state, terminal):
# C++ Wrapper call
return self._reward_function.reward(state, terminal)
def __reward_callback(self, data):
self.__last_received_reward = float(data.data)
logger.debug('Received reward: {}'.format(self.__last_received_reward))
self.__waiting_for_reward_update = False
......@@ -9,10 +9,9 @@ configuration RosActorNetwork {
state_topic : "/environment/state"
action_topic : "/environment/action"
reset_topic : "/environment/reset"
reward_topic: "/environment/reward"
}
reward_function : reward.rewardFunction
agent_name : "ddpg-agent"
num_episodes : 2500
......
package reward;
component RewardFunction {
ports
in Q^{16} state,
in B isTerminal,
out Q reward;
implementation Math {
Q speed = state(15);
Q angle = state(1);
reward = speed * cos(angle);
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment