Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
080c4249
Commit
080c4249
authored
Jul 14, 2019
by
Nicola Gatto
Browse files
Adapt new wrapper interface
parent
e680d6bb
Changes
8
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
080c4249
...
...
@@ -19,7 +19,7 @@
<CNNTrain.version>
0.3.4-SNAPSHOT
</CNNTrain.version>
<CNNArch2X.version>
0.0.2-SNAPSHOT
</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>
0.0.
1
</EMADL2PythonWrapper.version>
<EMADL2PythonWrapper.version>
0.0.
2-SNAPSHOT
</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>
18.0
</guava.version>
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
View file @
080c4249
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
com.google.common.collect.Maps
;
import
de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker
;
...
...
@@ -164,7 +165,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
setRootProjectModelsDir
(
modelsDirPath
.
toString
());
}
rewardFunctionSourceGenerator
.
generate
(
getRootProjectModelsDir
().
get
(),
EMAComponentInstanceSymbol
emaSymbol
=
rewardFunctionSourceGenerator
.
generate
(
getRootProjectModelsDir
().
get
(),
rewardFunctionRootModel
,
rewardFunctionOutputPath
);
fixArmadilloEmamGenerationOfFile
(
Paths
.
get
(
rewardFunctionOutputPath
,
String
.
join
(
"_"
,
fullNameOfComponent
)
+
".h"
));
...
...
@@ -175,12 +176,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if
(
pythonWrapperApi
.
checkIfPythonModuleBuildAvailable
())
{
final
String
rewardModuleOutput
=
Paths
.
get
(
getGenerationTargetPath
(),
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
).
toString
();
componentPortInformation
=
pythonWrapperApi
.
generateAndTryBuilding
(
getRootProjectModelsDir
().
get
()
,
rewardFunctionRootModel
,
pythonWrapperOutputPath
,
rewardModuleOutput
);
componentPortInformation
=
pythonWrapperApi
.
generateAndTryBuilding
(
emaSymbol
,
pythonWrapperOutputPath
,
rewardModuleOutput
);
}
else
{
Log
.
warn
(
"Cannot build wrapper automatically: OS not supported. Please build manually before starting training."
);
componentPortInformation
=
pythonWrapperApi
.
generate
(
getRootProjectModelsDir
().
get
(),
rewardFunctionRootModel
,
pythonWrapperOutputPath
);
componentPortInformation
=
pythonWrapperApi
.
generate
(
emaSymbol
,
pythonWrapperOutputPath
);
}
RewardFunctionParameterAdapter
functionParameter
=
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
new
FunctionParameterChecker
().
check
(
functionParameter
);
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionSourceGenerator.java
View file @
080c4249
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol
;
/**
*
*/
public
interface
RewardFunctionSourceGenerator
{
void
generate
(
String
modelPath
,
String
qualifiedName
,
String
targetPath
);
EMAComponentInstanceSymbol
generate
(
String
modelPath
,
String
qualifiedName
,
String
targetPath
);
}
\ No newline at end of file
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/IntegrationPythonWrapperTest.java
deleted
100644 → 0
View file @
e680d6bb
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.util.TrainedArchitectureMockFactory
;
import
de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture
;
import
de.se_rwth.commons.logging.Finding
;
import
de.se_rwth.commons.logging.Log
;
import
org.junit.Before
;
import
org.junit.Ignore
;
import
org.junit.Test
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.util.Arrays
;
import
java.util.stream.Collectors
;
import
static
junit
.
framework
.
TestCase
.
assertTrue
;
import
static
org
.
mockito
.
Mockito
.
mock
;
public
class
IntegrationPythonWrapperTest
extends
AbstractSymtabTest
{
private
RewardFunctionSourceGenerator
rewardFunctionSourceGenerator
;
@Before
public
void
setUp
()
{
// ensure an empty log
Log
.
getFindings
().
clear
();
Log
.
enableFailQuick
(
false
);
rewardFunctionSourceGenerator
=
mock
(
RewardFunctionSourceGenerator
.
class
);
}
@Test
public
void
testReinforcementConfigWithRewardGeneration
()
{
Log
.
getFindings
().
clear
();
Path
modelPath
=
Paths
.
get
(
"src/test/resources/valid_tests"
);
CNNTrain2Gluon
trainGenerator
=
new
CNNTrain2Gluon
(
rewardFunctionSourceGenerator
);
TrainedArchitecture
trainedArchitecture
=
TrainedArchitectureMockFactory
.
createTrainedArchitectureMock
();
trainGenerator
.
generate
(
modelPath
,
"ReinforcementConfig1"
,
trainedArchitecture
);
assertTrue
(
Log
.
getFindings
().
stream
().
filter
(
Finding:
:
isError
).
collect
(
Collectors
.
toList
()).
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code/ReinforcementConfig1"
),
Arrays
.
asList
(
"CNNTrainer_reinforcementConfig1.py"
,
"start_training.sh"
,
"reinforcement_learning/__init__.py"
,
"reinforcement_learning/strategy.py"
,
"reinforcement_learning/agent.py"
,
"reinforcement_learning/environment.py"
,
"reinforcement_learning/replay_memory.py"
,
"reinforcement_learning/util.py"
,
"reinforcement_learning/cnnarch_logger.py"
)
);
assertTrue
(
Paths
.
get
(
"./target/generated-sources-cnnarch/reward/pylib"
).
toFile
().
isDirectory
());
}
}
src/test/resources/target_code/ros-ddpg/CNNTrainer_rosActorNetwork.py
View file @
080c4249
...
...
@@ -49,6 +49,7 @@ if __name__ == "__main__":
'state_topic'
:
'/environment/state'
,
'action_topic'
:
'/environment/action'
,
'reset_topic'
:
'/environment/reset'
,
'reward_topic'
:
'/environment/reward'
,
}
env
=
reinforcement_learning
.
environment
.
RosEnvironment
(
**
env_params
)
...
...
src/test/resources/target_code/ros-ddpg/reinforcement_learning/environment.py
View file @
080c4249
...
...
@@ -2,29 +2,12 @@ import abc
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logger
=
logging
.
getLogger
(
__name__
)
import
reward_rewardFunction_executor
class
RewardFunction
(
object
):
def
__init__
(
self
):
self
.
__reward_wrapper
=
reward_rewardFunction_executor
.
reward_rewardFunction_executor
()
self
.
__reward_wrapper
.
init
()
def
reward
(
self
,
state
,
terminal
):
s
=
state
.
astype
(
'double'
)
t
=
bool
(
terminal
)
inp
=
reward_rewardFunction_executor
.
reward_rewardFunction_input
()
inp
.
state
=
s
inp
.
isTerminal
=
t
output
=
self
.
__reward_wrapper
.
execute
(
inp
)
return
output
.
reward
class
Environment
:
__metaclass__
=
abc
.
ABCMeta
def
__init__
(
self
):
self
.
_reward_function
=
RewardFunction
()
pass
@
abc
.
abstractmethod
def
reset
(
self
):
...
...
@@ -60,6 +43,8 @@ class RosEnvironment(Environment):
self
.
__waiting_for_terminal_update
=
False
self
.
__last_received_state
=
0
self
.
__last_received_terminal
=
True
self
.
__last_received_reward
=
0.0
self
.
__waiting_for_reward_update
=
False
rospy
.
loginfo
(
"Initialize node {0}"
.
format
(
ros_node_name
))
...
...
@@ -77,6 +62,9 @@ class RosEnvironment(Environment):
self
.
__terminal_state_subscriber
=
rospy
.
Subscriber
(
terminal_state_topic
,
Bool
,
self
.
__terminal_state_callback
)
rospy
.
loginfo
(
'Terminal State Subscriber registered with topic {}'
.
format
(
terminal_state_topic
))
self
.
__reward_subscriber
=
rospy
.
Subscriber
(
reward_topic
,
Float32
,
self
.
__reward_callback
)
rospy
.
loginfo
(
'Reward Subscriber registered with topic {}'
.
format
(
reward_topic
))
rate
=
rospy
.
Rate
(
10
)
thread
.
start_new_thread
(
rospy
.
spin
,
())
...
...
@@ -110,11 +98,12 @@ class RosEnvironment(Environment):
self
.
__waiting_for_state_update
=
True
self
.
__waiting_for_terminal_update
=
True
self
.
__waiting_for_reward_update
=
True
self
.
__step_publisher
.
publish
(
action_rospy
)
self
.
__wait_for_new_state
(
self
.
__step_publisher
,
action_rospy
)
next_state
=
self
.
__last_received_state
terminal
=
self
.
__last_received_terminal
reward
=
self
.
__
calc_reward
(
next_state
,
terminal
)
reward
=
self
.
__
last_received_reward
rospy
.
logdebug
(
'Calculated reward: {}'
.
format
(
reward
))
return
next_state
,
reward
,
terminal
,
0
...
...
@@ -123,7 +112,7 @@ class RosEnvironment(Environment):
time_of_timeout
=
time
.
time
()
+
self
.
__timeout_in_s
timeout_counter
=
0
while
(
self
.
__waiting_for_state_update
or
self
.
__waiting_for_terminal_update
):
or
self
.
__waiting_for_terminal_update
or
self
.
__waiting_for_reward_update
):
is_timeout
=
(
time
.
time
()
>
time_of_timeout
)
if
(
is_timeout
):
if
timeout_counter
<
3
:
...
...
@@ -150,6 +139,7 @@ class RosEnvironment(Environment):
logger
.
debug
(
'Received terminal: {}'
.
format
(
self
.
__last_received_terminal
))
self
.
__waiting_for_terminal_update
=
False
def
__calc_reward
(
self
,
state
,
terminal
):
# C++ Wrapper call
return
self
.
_reward_function
.
reward
(
state
,
terminal
)
def
__reward_callback
(
self
,
data
):
self
.
__last_received_reward
=
float
(
data
.
data
)
logger
.
debug
(
'Received reward: {}'
.
format
(
self
.
__last_received_reward
))
self
.
__waiting_for_reward_update
=
False
src/test/resources/valid_tests/ddpg-ros/RosActorNetwork.cnnt
View file @
080c4249
...
...
@@ -9,10 +9,9 @@ configuration RosActorNetwork {
state_topic : "/environment/state"
action_topic : "/environment/action"
reset_topic : "/environment/reset"
reward_topic: "/environment/reward"
}
reward_function : reward.rewardFunction
agent_name : "ddpg-agent"
num_episodes : 2500
...
...
src/test/resources/valid_tests/ddpg-ros/reward/RewardFunction.emadl
deleted
100644 → 0
View file @
e680d6bb
package
reward
;
component
RewardFunction
{
ports
in
Q
^{
16
}
state
,
in
B
isTerminal
,
out
Q
reward
;
implementation
Math
{
Q
speed
=
state
(
15
);
Q
angle
=
state
(
1
);
reward
=
speed
*
cos
(
angle
);
}
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment