Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
C
CNNArch2Gluon
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
1
Issues
1
List
Boards
Labels
Service Desk
Milestones
Iterations
Merge Requests
0
Merge Requests
0
Requirements
Requirements
List
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Test Cases
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Package Registry
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issue
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
7cd16d16
Commit
7cd16d16
authored
Oct 07, 2019
by
Nicola Gatto
Committed by
Evgeny Kusmenko
Oct 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Pipeline fix and new CNNTrain Integration
parent
7aeb41ef
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
83 changed files
with
271 additions
and
93 deletions
+271
-93
pom.xml
pom.xml
+2
-2
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
.../lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
+11
-3
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
...luongenerator/reinforcement/FunctionParameterChecker.java
+26
-2
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterCheckerTest.java
...generator/reinforcement/FunctionParameterCheckerTest.java
+215
-0
src/test/resources/target_code/CNNBufferFile.h
src/test/resources/target_code/CNNBufferFile.h
+0
-1
src/test/resources/target_code/CNNCreator_Alexnet.py
src/test/resources/target_code/CNNCreator_Alexnet.py
+0
-1
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
...esources/target_code/CNNCreator_CifarClassifierNetwork.py
+0
-1
src/test/resources/target_code/CNNCreator_VGG16.py
src/test/resources/target_code/CNNCreator_VGG16.py
+0
-1
src/test/resources/target_code/CNNDataLoader_Alexnet.py
src/test/resources/target_code/CNNDataLoader_Alexnet.py
+0
-1
src/test/resources/target_code/CNNDataLoader_CifarClassifierNetwork.py
...urces/target_code/CNNDataLoader_CifarClassifierNetwork.py
+0
-1
src/test/resources/target_code/CNNDataLoader_VGG16.py
src/test/resources/target_code/CNNDataLoader_VGG16.py
+0
-1
src/test/resources/target_code/CNNNet_Alexnet.py
src/test/resources/target_code/CNNNet_Alexnet.py
+16
-7
src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
...st/resources/target_code/CNNNet_CifarClassifierNetwork.py
+0
-1
src/test/resources/target_code/CNNNet_VGG16.py
src/test/resources/target_code/CNNNet_VGG16.py
+0
-1
src/test/resources/target_code/CNNPredictor_Alexnet.h
src/test/resources/target_code/CNNPredictor_Alexnet.h
+0
-1
src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
...sources/target_code/CNNPredictor_CifarClassifierNetwork.h
+0
-1
src/test/resources/target_code/CNNPredictor_VGG16.h
src/test/resources/target_code/CNNPredictor_VGG16.h
+0
-1
src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
...est/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+0
-1
src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
...arget_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+0
-1
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+0
-1
src/test/resources/target_code/CNNTrainer_emptyConfig.py
src/test/resources/target_code/CNNTrainer_emptyConfig.py
+0
-1
src/test/resources/target_code/CNNTrainer_fullConfig.py
src/test/resources/target_code/CNNTrainer_fullConfig.py
+0
-1
src/test/resources/target_code/CNNTrainer_simpleConfig.py
src/test/resources/target_code/CNNTrainer_simpleConfig.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py
...e/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/agent.py
...code/ReinforcementConfig1/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/cnnarch_logger.py
...forcementConfig1/reinforcement_learning/cnnarch_logger.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/environment.py
...einforcementConfig1/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/replay_memory.py
...nforcementConfig1/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/strategy.py
...e/ReinforcementConfig1/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/util.py
..._code/ReinforcementConfig1/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/ReinforcementConfig1/start_training.sh
...ources/target_code/ReinforcementConfig1/start_training.sh
+1
-2
src/test/resources/target_code/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py
...e/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/agent.py
...code/ReinforcementConfig2/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/cnnarch_logger.py
...forcementConfig2/reinforcement_learning/cnnarch_logger.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/environment.py
...einforcementConfig2/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/replay_memory.py
...nforcementConfig2/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/strategy.py
...e/ReinforcementConfig2/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/util.py
..._code/ReinforcementConfig2/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/ReinforcementConfig2/start_training.sh
...ources/target_code/ReinforcementConfig2/start_training.sh
+0
-1
src/test/resources/target_code/ReinforcementConfig3/CNNTrainer_reinforcementConfig3.py
...e/ReinforcementConfig3/CNNTrainer_reinforcementConfig3.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/reinforcement_learning/agent.py
...code/ReinforcementConfig3/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/reinforcement_learning/cnnarch_logger.py
...forcementConfig3/reinforcement_learning/cnnarch_logger.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/reinforcement_learning/environment.py
...einforcementConfig3/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/reinforcement_learning/replay_memory.py
...nforcementConfig3/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/reinforcement_learning/strategy.py
...e/ReinforcementConfig3/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/reinforcement_learning/util.py
..._code/ReinforcementConfig3/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/ReinforcementConfig3/start_training.sh
...ources/target_code/ReinforcementConfig3/start_training.sh
+0
-1
src/test/resources/target_code/cmake/FindArmadillo.cmake
src/test/resources/target_code/cmake/FindArmadillo.cmake
+0
-1
src/test/resources/target_code/ddpg/CNNTrainer_actorNetwork.py
...est/resources/target_code/ddpg/CNNTrainer_actorNetwork.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/CNNCreator_CriticNetwork.py
...e/ddpg/reinforcement_learning/CNNCreator_CriticNetwork.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/CNNNet_CriticNetwork.py
..._code/ddpg/reinforcement_learning/CNNNet_CriticNetwork.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/agent.py
...esources/target_code/ddpg/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/cnnarch_logger.py
...target_code/ddpg/reinforcement_learning/cnnarch_logger.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/environment.py
...es/target_code/ddpg/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/replay_memory.py
.../target_code/ddpg/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/strategy.py
...urces/target_code/ddpg/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/ddpg/reinforcement_learning/util.py
...resources/target_code/ddpg/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/ddpg/start_training.sh
src/test/resources/target_code/ddpg/start_training.sh
+0
-1
src/test/resources/target_code/reinforcement_learning/agent.py
...est/resources/target_code/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/reinforcement_learning/environment.py
...sources/target_code/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/reinforcement_learning/replay_memory.py
...urces/target_code/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/reinforcement_learning/strategy.py
.../resources/target_code/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/reinforcement_learning/util.py
...test/resources/target_code/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/ros-ddpg/CNNTrainer_rosActorNetwork.py
...ources/target_code/ros-ddpg/CNNTrainer_rosActorNetwork.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/CNNCreator_RosCriticNetwork.py
...dpg/reinforcement_learning/CNNCreator_RosCriticNetwork.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/CNNNet_RosCriticNetwork.py
...os-ddpg/reinforcement_learning/CNNNet_RosCriticNetwork.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/agent.py
...rces/target_code/ros-ddpg/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/cnnarch_logger.py
...et_code/ros-ddpg/reinforcement_learning/cnnarch_logger.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/environment.py
...arget_code/ros-ddpg/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/replay_memory.py
...get_code/ros-ddpg/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/strategy.py
...s/target_code/ros-ddpg/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/ros-ddpg/reinforcement_learning/util.py
...urces/target_code/ros-ddpg/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/ros-ddpg/start_training.sh
src/test/resources/target_code/ros-ddpg/start_training.sh
+0
-1
src/test/resources/target_code/td3/CNNTrainer_tD3Config.py
src/test/resources/target_code/td3/CNNTrainer_tD3Config.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/CNNCreator_CriticNetwork.py
...de/td3/reinforcement_learning/CNNCreator_CriticNetwork.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/CNNNet_CriticNetwork.py
...t_code/td3/reinforcement_learning/CNNNet_CriticNetwork.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/agent.py
...resources/target_code/td3/reinforcement_learning/agent.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/cnnarch_logger.py
.../target_code/td3/reinforcement_learning/cnnarch_logger.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/environment.py
...ces/target_code/td3/reinforcement_learning/environment.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/replay_memory.py
...s/target_code/td3/reinforcement_learning/replay_memory.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/strategy.py
...ources/target_code/td3/reinforcement_learning/strategy.py
+0
-1
src/test/resources/target_code/td3/reinforcement_learning/util.py
.../resources/target_code/td3/reinforcement_learning/util.py
+0
-1
src/test/resources/target_code/td3/start_training.sh
src/test/resources/target_code/td3/start_training.sh
+0
-1
No files found.
pom.xml
View file @
7cd16d16
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-gluon-generator
</artifactId>
<artifactId>
cnnarch-gluon-generator
</artifactId>
<version>
0.2.
8
-SNAPSHOT
</version>
<version>
0.2.
9
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<!-- == PROJECT DEPENDENCIES ============================================= -->
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
<!-- .. SE-Libraries .................................................. -->
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.3.3-SNAPSHOT
</CNNArch.version>
<CNNArch.version>
0.3.3-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.3.
6
-SNAPSHOT
</CNNTrain.version>
<CNNTrain.version>
0.3.
7
-SNAPSHOT
</CNNTrain.version>
<CNNArch2X.version>
0.0.4-SNAPSHOT
</CNNArch2X.version>
<CNNArch2X.version>
0.0.4-SNAPSHOT
</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<embedded-montiarc-math-opt-generator>
0.1.4
</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>
0.0.2-SNAPSHOT
</EMADL2PythonWrapper.version>
<EMADL2PythonWrapper.version>
0.0.2-SNAPSHOT
</EMADL2PythonWrapper.version>
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
View file @
7cd16d16
...
@@ -151,7 +151,14 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
...
@@ -151,7 +151,14 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
// Generate Reward function if necessary
// Generate Reward function if necessary
if
(
configuration
.
getRlRewardFunction
().
isPresent
())
{
if
(
configuration
.
getRlRewardFunction
().
isPresent
())
{
generateRewardFunction
(
configuration
.
getRlRewardFunction
().
get
(),
Paths
.
get
(
rootProjectModelsDir
));
if
(
configuration
.
getTrainedArchitecture
().
isPresent
())
{
generateRewardFunction
(
configuration
.
getTrainedArchitecture
().
get
(),
configuration
.
getRlRewardFunction
().
get
(),
Paths
.
get
(
rootProjectModelsDir
));
}
else
{
Log
.
error
(
"No architecture model for the trained neural network but is required for "
+
"reinforcement learning configuration."
);
}
}
}
ftlContext
.
put
(
"trainerName"
,
trainerName
);
ftlContext
.
put
(
"trainerName"
,
trainerName
);
...
@@ -167,7 +174,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
...
@@ -167,7 +174,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return
fileContentMap
;
return
fileContentMap
;
}
}
private
void
generateRewardFunction
(
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
private
void
generateRewardFunction
(
NNArchitectureSymbol
trainedArchitecture
,
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
GeneratorPythonWrapperStandaloneApi
pythonWrapperApi
=
new
GeneratorPythonWrapperStandaloneApi
();
GeneratorPythonWrapperStandaloneApi
pythonWrapperApi
=
new
GeneratorPythonWrapperStandaloneApi
();
List
<
String
>
fullNameOfComponent
=
rewardFunctionSymbol
.
getRewardFunctionComponentName
();
List
<
String
>
fullNameOfComponent
=
rewardFunctionSymbol
.
getRewardFunctionComponentName
();
...
@@ -200,7 +208,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
...
@@ -200,7 +208,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
componentPortInformation
=
pythonWrapperApi
.
generate
(
emaSymbol
,
pythonWrapperOutputPath
);
componentPortInformation
=
pythonWrapperApi
.
generate
(
emaSymbol
,
pythonWrapperOutputPath
);
}
}
RewardFunctionParameterAdapter
functionParameter
=
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
RewardFunctionParameterAdapter
functionParameter
=
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
new
FunctionParameterChecker
().
check
(
functionParameter
);
new
FunctionParameterChecker
().
check
(
functionParameter
,
trainedArchitecture
);
rewardFunctionSymbol
.
setRewardFunctionParameter
(
functionParameter
);
rewardFunctionSymbol
.
setRewardFunctionParameter
(
functionParameter
);
}
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
View file @
7cd16d16
/* (c) https://github.com/MontiCore/monticore */
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.se_rwth.commons.logging.Log
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.List
;
/**
/**
*
*
*/
*/
...
@@ -11,21 +14,42 @@ public class FunctionParameterChecker {
...
@@ -11,21 +14,42 @@ public class FunctionParameterChecker {
private
String
inputTerminalParameterName
;
private
String
inputTerminalParameterName
;
private
String
outputParameterName
;
private
String
outputParameterName
;
private
RewardFunctionParameterAdapter
rewardFunctionParameter
;
private
RewardFunctionParameterAdapter
rewardFunctionParameter
;
private
NNArchitectureSymbol
trainedArchitecture
;
public
FunctionParameterChecker
()
{
public
FunctionParameterChecker
()
{
}
}
public
void
check
(
final
RewardFunctionParameterAdapter
rewardFunctionParameter
)
{
public
void
check
(
final
RewardFunctionParameterAdapter
rewardFunctionParameter
,
final
NNArchitectureSymbol
trainedArchitecture
)
{
assert
rewardFunctionParameter
!=
null
;
assert
trainedArchitecture
!=
null
;
this
.
rewardFunctionParameter
=
rewardFunctionParameter
;
this
.
rewardFunctionParameter
=
rewardFunctionParameter
;
this
.
trainedArchitecture
=
trainedArchitecture
;
retrieveParameterNames
();
retrieveParameterNames
();
checkHasExactlyTwoInputs
();
checkHasExactlyTwoInputs
();
checkHasExactlyOneOutput
();
checkHasExactlyOneOutput
();
checkHasStateAndTerminalInput
();
checkHasStateAndTerminalInput
();
checkInputStateDimension
();
checkInputTerminalTypeAndDimension
();
checkInputTerminalTypeAndDimension
();
checkStateDimensionEqualsTrainedArchitectureState
();
checkInputStateDimension
();
checkOutputDimension
();
checkOutputDimension
();
}
}
private
void
checkStateDimensionEqualsTrainedArchitectureState
()
{
failIfConditionFails
(
stateInputOfNNArchitectureIsEqualToRewardState
(),
"State dimension of trained architecture is not equal to reward state dimensions."
);
}
private
boolean
stateInputOfNNArchitectureIsEqualToRewardState
()
{
assert
trainedArchitecture
.
getInputs
().
size
()
==
1
:
"Trained architecture is not a policy network."
;
final
String
nnStateInputName
=
trainedArchitecture
.
getInputs
().
get
(
0
);
final
List
<
Integer
>
dimensions
=
trainedArchitecture
.
getDimensions
().
get
(
nnStateInputName
);
return
rewardFunctionParameter
.
getInputPortDimensionOfPort
(
inputStateParameterName
).
isPresent
()
&&
rewardFunctionParameter
.
getInputPortDimensionOfPort
(
inputStateParameterName
).
get
().
equals
(
dimensions
);
}
private
void
checkHasExactlyTwoInputs
()
{
private
void
checkHasExactlyTwoInputs
()
{
failIfConditionFails
(
functionHasTwoInputs
(),
"Reward function must have exactly two input parameters: "
failIfConditionFails
(
functionHasTwoInputs
(),
"Reward function must have exactly two input parameters: "
+
"One input needs to represents the environment's state and another input needs to be a "
+
"One input needs to represents the environment's state and another input needs to be a "
...
...
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterCheckerTest.java
0 → 100644
View file @
7cd16d16
package
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement
;
import
com.google.common.collect.ImmutableMap
;
import
com.google.common.collect.Lists
;
import
de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortDirection
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable
;
import
de.se_rwth.commons.logging.Finding
;
import
de.se_rwth.commons.logging.Log
;
import
org.junit.Before
;
import
org.junit.Test
;
import
java.util.List
;
import
java.util.Map
;
import
static
org
.
junit
.
Assert
.*;
import
static
org
.
mockito
.
Mockito
.
mock
;
import
static
org
.
mockito
.
Mockito
.
when
;
public
class
FunctionParameterCheckerTest
{
private
static
final
List
<
Integer
>
STATE_DIMENSIONS
=
Lists
.
newArrayList
(
3
,
2
,
4
);
private
static
final
PortVariable
STATE_PORT
=
PortVariable
.
multidimensionalVariableFrom
(
"input1"
,
EmadlType
.
Q
,
PortDirection
.
INPUT
,
STATE_DIMENSIONS
);
private
static
final
PortVariable
TERMINAL_PORT
=
PortVariable
.
primitiveVariableFrom
(
"input2"
,
EmadlType
.
B
,
PortDirection
.
INPUT
);
private
static
final
PortVariable
OUTPUT_PORT
=
PortVariable
.
primitiveVariableFrom
(
"output1"
,
EmadlType
.
Q
,
PortDirection
.
OUTPUT
);
private
static
final
String
COMPONENT_NAME
=
"TestRewardComponent"
;
FunctionParameterChecker
uut
=
new
FunctionParameterChecker
();
@Before
public
void
setup
()
{
Log
.
getFindings
().
clear
();
Log
.
enableFailQuick
(
false
);
}
@Test
public
void
validReward
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertEquals
(
0
,
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
());
}
@Test
public
void
invalidRewardWithOneInput
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithOneInput
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
anyMatch
(
Finding:
:
isError
));
}
@Test
public
void
invalidRewardWithTwoOutputs
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithTwoOutputs
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
anyMatch
(
Finding:
:
isError
));
}
@Test
public
void
invalidRewardWithTerminalHasQType
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithTwoQInputs
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
anyMatch
(
Finding:
:
isError
));
}
@Test
public
void
invalidRewardWithNonScalarOutput
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getComponentWithNonScalarOutput
();
// when
uut
.
check
(
adapter
,
getValidTrainedArchitecture
());
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
@Test
public
void
invalidRewardStateUnequalToTrainedArchitectureState1
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
NNArchitectureSymbol
trainedArchitectureWithDifferenDimension
=
getTrainedArchitectureWithStateDimensions
(
Lists
.
newArrayList
(
6
));
// when
uut
.
check
(
adapter
,
trainedArchitectureWithDifferenDimension
);
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
@Test
public
void
invalidRewardStateUnequalToTrainedArchitectureState2
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
NNArchitectureSymbol
trainedArchitectureWithDifferenDimension
=
getTrainedArchitectureWithStateDimensions
(
Lists
.
newArrayList
(
3
,
8
));
// when
uut
.
check
(
adapter
,
trainedArchitectureWithDifferenDimension
);
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
@Test
public
void
invalidRewardStateUnequalToTrainedArchitectureState3
()
{
// given
RewardFunctionParameterAdapter
adapter
=
getValidRewardAdapter
();
NNArchitectureSymbol
trainedArchitectureWithDifferenDimension
=
getTrainedArchitectureWithStateDimensions
(
Lists
.
newArrayList
(
2
,
4
,
3
));
// when
uut
.
check
(
adapter
,
trainedArchitectureWithDifferenDimension
);
List
<
Finding
>
findings
=
Log
.
getFindings
();
assertTrue
(
findings
.
stream
().
filter
(
Finding:
:
isError
).
count
()
==
1
);
}
private
RewardFunctionParameterAdapter
getComponentWithNonScalarOutput
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
getValidInputPortVariables
());
List
<
PortVariable
>
outputs
=
Lists
.
newArrayList
(
PortVariable
.
multidimensionalVariableFrom
(
"output"
,
EmadlType
.
Q
,
PortDirection
.
OUTPUT
,
Lists
.
newArrayList
(
2
,
2
)));
componentPortInformation
.
addAllOutputs
(
outputs
);
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getComponentWithTwoQInputs
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
List
<
PortVariable
>
inputs
=
Lists
.
newArrayList
(
STATE_PORT
,
PortVariable
.
multidimensionalVariableFrom
(
"input2"
,
EmadlType
.
Q
,
PortDirection
.
INPUT
,
Lists
.
newArrayList
(
2
,
3
,
2
)));
componentPortInformation
.
addAllInputs
(
inputs
);
componentPortInformation
.
addAllOutputs
(
getValidOutputPorts
());
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getComponentWithTwoOutputs
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
getValidInputPortVariables
());
List
<
PortVariable
>
outputs
=
getValidOutputPorts
();
outputs
.
add
(
PortVariable
.
primitiveVariableFrom
(
"output2"
,
EmadlType
.
B
,
PortDirection
.
OUTPUT
));
componentPortInformation
.
addAllOutputs
(
outputs
);
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getComponentWithOneInput
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
Lists
.
newArrayList
(
STATE_PORT
));
componentPortInformation
.
addAllOutputs
(
getValidOutputPorts
());
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
RewardFunctionParameterAdapter
getValidRewardAdapter
()
{
ComponentPortInformation
componentPortInformation
=
new
ComponentPortInformation
(
COMPONENT_NAME
);
componentPortInformation
.
addAllInputs
(
getValidInputPortVariables
());
componentPortInformation
.
addAllOutputs
(
getValidOutputPorts
());
return
new
RewardFunctionParameterAdapter
(
componentPortInformation
);
}
private
List
<
PortVariable
>
getValidOutputPorts
()
{
return
Lists
.
newArrayList
(
OUTPUT_PORT
);
}
private
List
<
PortVariable
>
getValidInputPortVariables
()
{
return
Lists
.
newArrayList
(
STATE_PORT
,
TERMINAL_PORT
);
}
private
NNArchitectureSymbol
getValidTrainedArchitecture
()
{
NNArchitectureSymbol
nnArchitectureSymbol
=
mock
(
NNArchitectureSymbol
.
class
);
final
String
stateInputName
=
"stateInput"
;
when
(
nnArchitectureSymbol
.
getInputs
()).
thenReturn
(
Lists
.
newArrayList
(
stateInputName
));
when
(
nnArchitectureSymbol
.
getDimensions
()).
thenReturn
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
stateInputName
,
STATE_DIMENSIONS
)
.
build
());
return
nnArchitectureSymbol
;
}
private
NNArchitectureSymbol
getTrainedArchitectureWithStateDimensions
(
final
List
<
Integer
>
dimensions
)
{
NNArchitectureSymbol
nnArchitectureSymbol
=
mock
(
NNArchitectureSymbol
.
class
);
final
String
stateInputName
=
"stateInput"
;
when
(
nnArchitectureSymbol
.
getInputs
()).
thenReturn
(
Lists
.
newArrayList
(
stateInputName
));
when
(
nnArchitectureSymbol
.
getDimensions
()).
thenReturn
(
ImmutableMap
.<
String
,
List
<
Integer
>>
builder
()
.
put
(
stateInputName
,
dimensions
)
.
build
());
return
nnArchitectureSymbol
;
}
}
\ No newline at end of file
src/test/resources/target_code/CNNBufferFile.h
View file @
7cd16d16
/* (c) https://github.com/MontiCore/monticore */
#ifndef CNNBUFFERFILE_H
#ifndef CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
...
...
src/test/resources/target_code/CNNCreator_Alexnet.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
mxnet
as
mx
import
mxnet
as
mx
import
logging
import
logging
import
os
import
os
...
...
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
mxnet
as
mx
import
mxnet
as
mx
import
logging
import
logging
import
os
import
os
...
...
src/test/resources/target_code/CNNCreator_VGG16.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
mxnet
as
mx
import
mxnet
as
mx
import
logging
import
logging
import
os
import
os
...
...
src/test/resources/target_code/CNNDataLoader_Alexnet.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
os
import
os
import
h5py
import
h5py
import
mxnet
as
mx
import
mxnet
as
mx
...
...
src/test/resources/target_code/CNNDataLoader_CifarClassifierNetwork.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
os
import
os
import
h5py
import
h5py
import
mxnet
as
mx
import
mxnet
as
mx
...
...
src/test/resources/target_code/CNNDataLoader_VGG16.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
os
import
os
import
h5py
import
h5py
import
mxnet
as
mx
import
mxnet
as
mx
...
...
src/test/resources/target_code/CNNNet_Alexnet.py
View file @
7cd16d16
# (c) https://github.com/MontiCore/monticore
import
mxnet
as
mx
import
mxnet
as
mx
import
numpy
as
np
import
numpy
as
np
from
mxnet
import
gluon
from
mxnet
import
gluon
...
@@ -91,13 +90,14 @@ class Net_0(gluon.HybridBlock):
...
@@ -91,13 +90,14 @@ class Net_0(gluon.HybridBlock):
else
:
else
:
self
.
input_normalization_data_
=
NoNormalization
()
self
.
input_normalization_data_
=
NoNormalization
()
self
.
conv1_padding
=
Padding
(
padding
=
(
0
,
0
,
0
,
0
,
2
,
1
,
2
,
1
))
self
.
conv1_padding
=
Padding
(
padding
=
(
0
,
0
,
-
1
,
0
,
0
,
0
,
0
,
0
))
self
.
conv1_
=
gluon
.
nn
.
Conv2D
(
channels
=
96
,
self
.
conv1_
=
gluon
.
nn
.
Conv2D
(
channels
=
96
,
kernel_size
=
(
11
,
11
),
kernel_size
=
(
11
,
11
),
strides
=
(
4
,
4
),
strides
=
(
4
,
4
),
use_bias
=
True
)
use_bias
=
True
)
# conv1_, output shape: {[96,55,55]}
# conv1_, output shape: {[96,55,55]}
self
.
pool1_padding
=
Padding
(
padding
=
(
0
,
0
,
-
1
,
0
,
0
,
0
,
0
,
0
))
self
.
pool1_
=
gluon
.
nn
.
MaxPool2D
(
self
.
pool1_
=
gluon
.
nn
.
MaxPool2D
(
pool_size
=
(
3
,
3
),
pool_size
=
(
3
,
3
),
strides
=
(
2
,
2
))
strides
=
(
2
,
2
))
...
@@ -114,6 +114,7 @@ class Net_0(gluon.HybridBlock):
...
@@ -114,6 +114,7 @@ class Net_0(gluon.HybridBlock):
use_bias
=
True
)
use_bias
=
True
)
# conv2_1_, output shape: {[128,27,27]}
# conv2_1_, output shape: {[128,27,27]}
self
.
pool2_1_padding
=
Padding
(
padding
=
(
0
,
0
,
-
1
,
0
,
0
,
0
,
0
,
0
))
self
.
pool2_1_
=
gluon
.
nn
.
MaxPool2D
(
self
.
pool2_1_
=
gluon
.
nn
.
MaxPool2D
(
pool_size
=
(
3
,
3
),
pool_size
=
(
3
,
3
),
strides
=
(
2
,
2
))
strides
=
(
2
,
2
))
...
@@ -127,6 +128,7 @@ class Net_0(gluon.HybridBlock):
...
@@ -127,6 +128,7 @@ class Net_0(gluon.HybridBlock):
use_bias
=
True
)
use_bias
=
True
)
# conv2_2_, output shape: {[128,27,27]}
# conv2_2_, output shape: {[128,27,27]}
self
.
pool2_2_padding
=
Padding
(
padding
=
(
0
,
0
,
-
1
,
0
,
0
,
0
,
0
,
0
))
self
.
pool2_2_
=
gluon
.
nn
.
MaxPool2D
(
self
.
pool2_2_
=
gluon
.
nn
.
MaxPool2D
(
pool_size
=
(
3
,
3
),
pool_size
=
(
3
,
3
),
strides
=
(
2
,
2
))
strides
=
(
2
,
2
))
...
@@ -162,6 +164,7 @@ class Net_0(gluon.HybridBlock):
...
@@ -162,6 +164,7 @@ class Net_0(gluon.HybridBlock):
use_bias
=
True
)
use_bias
=
True
)
# conv5_1_, output shape: {[128,13,13]}
# conv5_1_, output shape: {[128,13,13]}
self
.
pool5_1_padding
=
Padding
(
padding
=
(
0
,
0
,
-
1
,
0
,
0
,
0
,
0
,
0
))
self
.
pool5_1_
=
gluon
.
nn
.
MaxPool2D
(
self
.
pool5_1_
=
gluon
.
nn
.
MaxPool2D
(
pool_size
=
(
3
,
3
),
pool_size
=
(
3
,
3
),
strides
=
(
2
,
2
))
strides
=
(
2
,
2
))
...
@@ -183,6 +186,7 @@ class Net_0(gluon.HybridBlock):
...
@@ -183,6 +186,7 @@ class Net_0(gluon.HybridBlock):
use_bias
=
True
)
use_bias
=
True
)
# conv5_2_, output shape: {[128,13,13]}
# conv5_2_, output shape: {[128,13,13]}
self
.
pool5_2_padding
=
Padding
(
padding
=
(
0
,
0
,
-
1
,
0
,
0
,
0
,
0
,
0
))
self
.
pool5_2_
=
gluon
.
nn
.
MaxPool2D
(
self
.
pool5_2_
=
gluon
.
nn
.
MaxPool2D
(
pool_size
=
(
3
,
3
),
pool_size
=
(
3
,
3
),
strides
=
(
2
,
2
))
strides
=
(
2
,
2
))
...
@@ -217,7 +221,8 @@ class Net_0(gluon.HybridBlock):
...
@@ -217,7 +221,8 @@ class Net_0(gluon.HybridBlock):
beta
=
0.75
,
beta
=
0.75
,
knorm
=
2
,
knorm
=
2
,
nsize
=
5
)
nsize
=
5
)
pool1_
=
self
.
pool1_
(
lrn1_
)
pool1_padding
=
self
.
pool1_padding
(
lrn1_
)
pool1_
=
self
.
pool1_
(
pool1_padding
)
relu1_
=
self
.
relu1_
(
pool1_
)
relu1_
=
self
.
relu1_
(
pool1_
)
split1_
=
self
.
split1_
(
relu1_
)
split1_
=
self
.
split1_
(
relu1_
)
get2_1_
=
split1_
[
0
]
get2_1_
=
split1_
[
0
]
...
@@ -228,7 +233,8 @@ class Net_0(gluon.HybridBlock):
...
@@ -228,7 +233,8 @@ class Net_0(gluon.HybridBlock):
beta
=
0.75
,
beta
=
0.75
,
knorm
=
2
,
knorm
=
2
,
nsize
=
5
)
nsize
=
5
)
pool2_1_
=
self
.
pool2_1_
(
lrn2_1_
)