Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
C
CNNArch2Gluon
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
1
Issues
1
List
Boards
Labels
Service Desk
Milestones
Iterations
Merge Requests
0
Merge Requests
0
Requirements
Requirements
List
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Test Cases
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Package Registry
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issue
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
d03d1284
Commit
d03d1284
authored
Apr 14, 2020
by
Evgeny Kusmenko
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'develop' into 'master'
Develop See merge request
!28
parents
67579b12
06caaddd
Changes
68
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
68 changed files
with
3783 additions
and
458 deletions
+3783
-458
.gitlab-ci.yml
.gitlab-ci.yml
+12
-12
hs_err_pid1468.log
hs_err_pid1468.log
+1015
-0
pom.xml
pom.xml
+5
-5
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
...e/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
+0
-2
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
...arch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
+2
-1
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
...narch/gluongenerator/CNNArch2GluonTemplateController.java
+3
-0
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
.../lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
+35
-18
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/preprocessing/PreprocessingComponentParameterAdapter.java
...preprocessing/PreprocessingComponentParameterAdapter.java
+137
-0
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/preprocessing/PreprocessingPortChecker.java
...luongenerator/preprocessing/PreprocessingPortChecker.java
+60
-0
src/main/resources/templates/gluon/CNNCreator.ftl
src/main/resources/templates/gluon/CNNCreator.ftl
+55
-0
src/main/resources/templates/gluon/CNNDataLoader.ftl
src/main/resources/templates/gluon/CNNDataLoader.ftl
+4
-3
src/main/resources/templates/gluon/CNNGanTrainer.ftl
src/main/resources/templates/gluon/CNNGanTrainer.ftl
+142
-141
src/main/resources/templates/gluon/CNNNet.ftl
src/main/resources/templates/gluon/CNNNet.ftl
+0
-25
src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
src/main/resources/templates/gluon/CNNSupervisedTrainer.ftl
+93
-9
src/main/resources/templates/gluon/CNNTrainer.ftl
src/main/resources/templates/gluon/CNNTrainer.ftl
+12
-3
src/main/resources/templates/gluon/elements/Crop.ftl
src/main/resources/templates/gluon/elements/Crop.ftl
+11
-0
src/main/resources/templates/gluon/elements/UpConvolution.ftl
...main/resources/templates/gluon/elements/UpConvolution.ftl
+0
-0
src/main/resources/templates/gluon/gan/InputGenerator.ftl
src/main/resources/templates/gluon/gan/InputGenerator.ftl
+30
-14
src/main/resources/templates/gluon/gan/Trainer.ftl
src/main/resources/templates/gluon/gan/Trainer.ftl
+37
-15
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
.../lang/monticar/cnnarch/gluongenerator/GenerationTest.java
+55
-4
src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/preprocessing/PreprocessingParameterCheckerTest.java
...ator/preprocessing/PreprocessingParameterCheckerTest.java
+141
-0
src/test/resources/architectures/CropTest.cnna
src/test/resources/architectures/CropTest.cnna
+9
-0
src/test/resources/target_code/CNNCreator_Alexnet.py
src/test/resources/target_code/CNNCreator_Alexnet.py
+39
-0
src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
...esources/target_code/CNNCreator_CifarClassifierNetwork.py
+39
-0
src/test/resources/target_code/CNNCreator_VGG16.py
src/test/resources/target_code/CNNCreator_VGG16.py
+39
-0
src/test/resources/target_code/CNNDataLoader_Alexnet.py
src/test/resources/target_code/CNNDataLoader_Alexnet.py
+4
-3
src/test/resources/target_code/CNNDataLoader_CifarClassifierNetwork.py
...urces/target_code/CNNDataLoader_CifarClassifierNetwork.py
+4
-3
src/test/resources/target_code/CNNDataLoader_VGG16.py
src/test/resources/target_code/CNNDataLoader_VGG16.py
+4
-3
src/test/resources/target_code/CNNNet_Alexnet.py
src/test/resources/target_code/CNNNet_Alexnet.py
+0
-13
src/test/resources/target_code/CNNNet_CifarClassifierNetwork.py
...st/resources/target_code/CNNNet_CifarClassifierNetwork.py
+0
-13
src/test/resources/target_code/CNNNet_VGG16.py
src/test/resources/target_code/CNNNet_VGG16.py
+0
-13
src/test/resources/target_code/CNNSupervisedTrainer_Alexnet.py
...est/resources/target_code/CNNSupervisedTrainer_Alexnet.py
+94
-10
src/test/resources/target_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
...arget_code/CNNSupervisedTrainer_CifarClassifierNetwork.py
+94
-10
src/test/resources/target_code/CNNSupervisedTrainer_Invariant.py
...t/resources/target_code/CNNSupervisedTrainer_Invariant.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_MultipleStreams.py
...urces/target_code/CNNSupervisedTrainer_MultipleStreams.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_RNNencdec.py
...t/resources/target_code/CNNSupervisedTrainer_RNNencdec.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_RNNsearch.py
...t/resources/target_code/CNNSupervisedTrainer_RNNsearch.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_RNNtest.py
...est/resources/target_code/CNNSupervisedTrainer_RNNtest.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_ResNeXt50.py
...t/resources/target_code/CNNSupervisedTrainer_ResNeXt50.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_Show_attend_tell.py
...rces/target_code/CNNSupervisedTrainer_Show_attend_tell.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_ThreeInputCNN_M14.py
...ces/target_code/CNNSupervisedTrainer_ThreeInputCNN_M14.py
+40
-10
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
src/test/resources/target_code/CNNSupervisedTrainer_VGG16.py
+94
-10
src/test/resources/target_code/CNNTrainer_emptyConfig.py
src/test/resources/target_code/CNNTrainer_emptyConfig.py
+1
-0
src/test/resources/target_code/CNNTrainer_fullConfig.py
src/test/resources/target_code/CNNTrainer_fullConfig.py
+2
-0
src/test/resources/target_code/CNNTrainer_simpleConfig.py
src/test/resources/target_code/CNNTrainer_simpleConfig.py
+1
-0
src/test/resources/target_code/ddpg/reinforcement_learning/CNNCreator_CriticNetwork.py
...e/ddpg/reinforcement_learning/CNNCreator_CriticNetwork.py
+42
-0
src/test/resources/target_code/ddpg/reinforcement_learning/CNNNet_CriticNetwork.py
..._code/ddpg/reinforcement_learning/CNNNet_CriticNetwork.py
+0
-16
src/test/resources/target_code/default-gan/CNNTrainer_defaultGAN.py
...esources/target_code/default-gan/CNNTrainer_defaultGAN.py
+62
-0
src/test/resources/target_code/default-gan/gan/CNNCreator_Discriminator.py
...s/target_code/default-gan/gan/CNNCreator_Discriminator.py
+99
-0
src/test/resources/target_code/default-gan/gan/CNNNet_Discriminator.py
...urces/target_code/default-gan/gan/CNNNet_Discriminator.py
+176
-0
src/test/resources/target_code/info-gan/CNNTrainer_infoGAN.py
...test/resources/target_code/info-gan/CNNTrainer_infoGAN.py
+65
-0
src/test/resources/target_code/info-gan/gan/CNNCreator_InfoDiscriminator.py
.../target_code/info-gan/gan/CNNCreator_InfoDiscriminator.py
+102
-0
src/test/resources/target_code/info-gan/gan/CNNCreator_InfoQNetwork.py
...urces/target_code/info-gan/gan/CNNCreator_InfoQNetwork.py
+99
-0
src/test/resources/target_code/info-gan/gan/CNNNet_InfoDiscriminator.py
...rces/target_code/info-gan/gan/CNNNet_InfoDiscriminator.py
+177
-0
src/test/resources/target_code/info-gan/gan/CNNNet_InfoQNetwork.py
...resources/target_code/info-gan/gan/CNNNet_InfoQNetwork.py
+124
-0
src/test/resources/target_code/ros-ddpg/reinforcement_learning/CNNCreator_RosCriticNetwork.py
...dpg/reinforcement_learning/CNNCreator_RosCriticNetwork.py
+42
-0
src/test/resources/target_code/ros-ddpg/reinforcement_learning/CNNNet_RosCriticNetwork.py
...os-ddpg/reinforcement_learning/CNNNet_RosCriticNetwork.py
+0
-16
src/test/resources/target_code/td3/reinforcement_learning/CNNCreator_CriticNetwork.py
...de/td3/reinforcement_learning/CNNCreator_CriticNetwork.py
+42
-0
src/test/resources/target_code/td3/reinforcement_learning/CNNNet_CriticNetwork.py
...t_code/td3/reinforcement_learning/CNNNet_CriticNetwork.py
+0
-16
src/test/resources/valid_tests/FullConfig.cnnt
src/test/resources/valid_tests/FullConfig.cnnt
+1
-0
src/test/resources/valid_tests/default-gan/DefaultGAN.cnnt
src/test/resources/valid_tests/default-gan/DefaultGAN.cnnt
+28
-0
src/test/resources/valid_tests/default-gan/arc/DefaultGAN.cnna
...est/resources/valid_tests/default-gan/arc/DefaultGAN.cnna
+22
-0
src/test/resources/valid_tests/default-gan/arc/Discriminator.cnna
.../resources/valid_tests/default-gan/arc/Discriminator.cnna
+20
-0
src/test/resources/valid_tests/info-gan/InfoGAN.cnnt
src/test/resources/valid_tests/info-gan/InfoGAN.cnnt
+29
-0
src/test/resources/valid_tests/info-gan/arc/InfoDiscriminator.cnna
...resources/valid_tests/info-gan/arc/InfoDiscriminator.cnna
+25
-0
src/test/resources/valid_tests/info-gan/arc/InfoGAN.cnna
src/test/resources/valid_tests/info-gan/arc/InfoGAN.cnna
+22
-0
src/test/resources/valid_tests/info-gan/arc/InfoQNetwork.cnna
...test/resources/valid_tests/info-gan/arc/InfoQNetwork.cnna
+12
-0
src/test/resources/valid_tests/weights_paths.txt
src/test/resources/valid_tests/weights_paths.txt
+1
-0
No files found.
.gitlab-ci.yml
View file @
d03d1284
# (c) https://github.com/MontiCore/monticore
stages
:
-
windows
#
- windows
-
linux
masterJobLinux
:
...
...
@@ -20,17 +20,17 @@ masterJobLinux:
-
.gitlab-ci.yml
masterJobWindows
:
stage
:
windows
script
:
-
mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags
:
-
Windows10
except
:
changes
:
-
README.md
-
.gitignore
-
.gitlab-ci.yml
#
masterJobWindows:
#
stage: windows
#
script:
#
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
#
tags:
#
- Windows10
#
except:
#
changes:
#
- README.md
#
- .gitignore
#
- .gitlab-ci.yml
BranchJobLinux
:
...
...
hs_err_pid1468.log
0 → 100644
View file @
d03d1284
This diff is collapsed.
Click to expand it.
pom.xml
View file @
d03d1284
...
...
@@ -17,12 +17,12 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.3.
4
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.3.
9
-SNAPSHOT
</CNNTrain.version>
<CNNArch2X.version>
0.0.
5
-SNAPSHOT
</CNNArch2X.version>
<CNNArch.version>
0.3.
5
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.3.
10
-SNAPSHOT
</CNNTrain.version>
<CNNArch2X.version>
0.0.
6
-SNAPSHOT
</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>
0.1.6
</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>
0.0.2-SNAPSHOT
</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>
18.0
</guava.version>
<junit.version>
4.12
</junit.version>
...
...
@@ -144,7 +144,7 @@
</dependency>
</dependencies>
<!-- == PROJECT BUILD SETTINGS =========================================== -->
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
View file @
d03d1284
...
...
@@ -84,8 +84,6 @@ public class CNNArch2Gluon extends CNNArchGenerator {
CNNArch2GluonTemplateController
archTc
=
new
CNNArch2GluonTemplateController
(
architecture
,
templateConfiguration
);
archTc
.
getStreamOutputDomains
(
archTc
.
getArchitecture
().
getStreams
().
get
(
0
));
fileContentMap
.
putAll
(
compilePythonFiles
(
archTc
,
architecture
));
fileContentMap
.
putAll
(
compileCppFiles
(
archTc
));
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonLayerSupportChecker.java
View file @
d03d1284
...
...
@@ -9,7 +9,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
public
CNNArch2GluonLayerSupportChecker
()
{
supportedLayerList
.
add
(
AllPredefinedLayers
.
FULLY_CONNECTED_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
CONVOLUTION_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
TRANS_CONV
_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
UP_CONVOLUTION
_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
SOFTMAX_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
SIGMOID_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
TANH_NAME
);
...
...
@@ -40,6 +40,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList
.
add
(
AllPredefinedLayers
.
REDUCE_SUM_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
BROADCAST_ADD_NAME
);
supportedLayerList
.
add
(
AllPredefinedLayers
.
RESHAPE_NAME
);
// supportedLayerList.add(AllPredefinedLayers.CROP_NAME);
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java
View file @
d03d1284
...
...
@@ -415,6 +415,9 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
dimensions
.
add
(
intDimension
.
toString
());
}
if
(
dimensions
.
isEmpty
())
dimensions
.
add
(
"unknown"
);
String
name
=
getName
(
element
);
if
(
outputAsArray
&&
element
.
isOutput
()
&&
element
instanceof
VariableSymbol
)
{
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
View file @
d03d1284
...
...
@@ -2,9 +2,12 @@
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
import
com.google.common.collect.Maps
;
import
de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.cncModel.EMAComponentSymbol
;
import
de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingPortChecker
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter
;
import
de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator
;
...
...
@@ -78,9 +81,9 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
try
{
Iterator
var6
=
fileContents
.
keySet
().
iterator
();
while
(
var6
.
hasNext
())
{
String
fileName
=
(
String
)
var6
.
next
();
genCPP
.
generateFile
(
new
FileContent
((
String
)
fileContents
.
get
(
fileName
),
fileName
));
while
(
var6
.
hasNext
())
{
String
fileName
=
(
String
)
var6
.
next
();
genCPP
.
generateFile
(
new
FileContent
((
String
)
fileContents
.
get
(
fileName
),
fileName
));
}
}
catch
(
IOException
var8
)
{
Log
.
error
(
"CNNTrainer file could not be generated"
+
var8
.
getMessage
());
...
...
@@ -98,6 +101,19 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
generateFilesFromConfigurationSymbol
(
configurationSymbol
);
}
public
void
generate
(
Path
modelsDirPath
,
String
rootModelName
,
NNArchitectureSymbol
trainedArchitecture
,
NNArchitectureSymbol
discriminatorNetwork
,
NNArchitectureSymbol
qNetwork
)
{
ConfigurationSymbol
configurationSymbol
=
this
.
getConfigurationSymbol
(
modelsDirPath
,
rootModelName
);
configurationSymbol
.
setTrainedArchitecture
(
trainedArchitecture
);
configurationSymbol
.
setDiscriminatorNetwork
(
discriminatorNetwork
);
configurationSymbol
.
setQNetwork
(
qNetwork
);
this
.
setRootProjectModelsDir
(
modelsDirPath
.
toString
());
generateFilesFromConfigurationSymbol
(
configurationSymbol
);
}
public
void
generate
(
Path
modelsDirPath
,
String
rootModelName
,
NNArchitectureSymbol
trainedArchitecture
)
{
generate
(
modelsDirPath
,
rootModelName
,
trainedArchitecture
,
null
);
}
...
...
@@ -117,16 +133,16 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if
(
configData
.
isSupervisedLearning
())
{
String
cnnTrainTemplateContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
);
fileContentMap
.
put
(
"CNNTrainer_"
+
getInstanceName
()
+
".py"
,
cnnTrainTemplateContent
);
}
else
if
(
configData
.
isGan
())
{
}
else
if
(
configData
.
isGan
())
{
final
String
trainerName
=
"CNNTrainer_"
+
getInstanceName
();
if
(!
configuration
.
getDiscriminatorNetwork
().
isPresent
())
{
if
(!
configuration
.
getDiscriminatorNetwork
().
isPresent
())
{
Log
.
error
(
"No architecture model for discriminator available but is required for chosen "
+
"GAN"
);
}
NNArchitectureSymbol
genericDisArchitectureSymbol
=
configuration
.
getDiscriminatorNetwork
().
get
();
ArchitectureSymbol
disArchitectureSymbol
=
((
ArchitectureAdapter
)
genericDisArchitectureSymbol
).
getArchitectureSymbol
();
=
((
ArchitectureAdapter
)
genericDisArchitectureSymbol
).
getArchitectureSymbol
();
CNNArch2Gluon
gluonGenerator
=
new
CNNArch2Gluon
();
gluonGenerator
.
setGenerationTargetPath
(
...
...
@@ -147,7 +163,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if
(
configuration
.
hasQNetwork
())
{
NNArchitectureSymbol
genericQArchitectureSymbol
=
configuration
.
getQNetwork
().
get
();
ArchitectureSymbol
qArchitectureSymbol
=
((
ArchitectureAdapter
)
genericQArchitectureSymbol
).
getArchitectureSymbol
();
=
((
ArchitectureAdapter
)
genericQArchitectureSymbol
).
getArchitectureSymbol
();
Map
<
String
,
String
>
qArchitectureFileContentMap
=
gluonGenerator
.
generateStringsAllowMultipleIO
(
qArchitectureSymbol
,
true
);
...
...
@@ -179,7 +195,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final
RLAlgorithm
rlAlgorithm
=
configData
.
getRlAlgorithm
();
if
(
rlAlgorithm
.
equals
(
RLAlgorithm
.
DDPG
)
||
rlAlgorithm
.
equals
(
RLAlgorithm
.
TD3
))
{
||
rlAlgorithm
.
equals
(
RLAlgorithm
.
TD3
))
{
if
(!
configuration
.
getCriticNetwork
().
isPresent
())
{
Log
.
error
(
"No architecture model for critic available but is required for chosen "
+
...
...
@@ -187,7 +203,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
NNArchitectureSymbol
genericArchitectureSymbol
=
configuration
.
getCriticNetwork
().
get
();
ArchitectureSymbol
architectureSymbol
=
((
ArchitectureAdapter
)
genericArchitectureSymbol
).
getArchitectureSymbol
();
=
((
ArchitectureAdapter
)
genericArchitectureSymbol
).
getArchitectureSymbol
();
CNNArch2Gluon
gluonGenerator
=
new
CNNArch2Gluon
();
gluonGenerator
.
setGenerationTargetPath
(
...
...
@@ -201,8 +217,8 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
creatorName
.
indexOf
(
'_'
)
+
1
,
creatorName
.
lastIndexOf
(
".py"
));
fileContentMap
.
putAll
(
architectureFileContentMap
.
entrySet
().
stream
().
collect
(
Collectors
.
toMap
(
k
->
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/"
+
k
.
getKey
(),
Map
.
Entry
::
getValue
))
k
->
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/"
+
k
.
getKey
(),
Map
.
Entry
::
getValue
))
);
ftlContext
.
put
(
"criticInstanceName"
,
criticInstanceName
);
...
...
@@ -215,7 +231,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
configuration
.
getRlRewardFunction
().
get
(),
Paths
.
get
(
rootProjectModelsDir
));
}
else
{
Log
.
error
(
"No architecture model for the trained neural network but is required for "
+
"reinforcement learning configuration."
);
"reinforcement learning configuration."
);
}
}
...
...
@@ -234,7 +250,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
private
void
generateRewardFunction
(
NNArchitectureSymbol
trainedArchitecture
,
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
RewardFunctionSymbol
rewardFunctionSymbol
,
Path
modelsDirPath
)
{
GeneratorPythonWrapperStandaloneApi
pythonWrapperApi
=
new
GeneratorPythonWrapperStandaloneApi
();
List
<
String
>
fullNameOfComponent
=
rewardFunctionSymbol
.
getRewardFunctionComponentName
();
...
...
@@ -271,7 +287,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
rewardFunctionSymbol
.
setRewardFunctionParameter
(
functionParameter
);
}
private
void
fixArmadilloEmamGenerationOfFile
(
Path
pathToBrokenFile
){
private
void
fixArmadilloEmamGenerationOfFile
(
Path
pathToBrokenFile
)
{
final
File
brokenFile
=
pathToBrokenFile
.
toFile
();
if
(
brokenFile
.
exists
())
{
try
{
...
...
@@ -301,19 +317,19 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/agent.py"
,
reinforcementAgentContent
);
final
String
reinforcementStrategyContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/agent/Strategy.ftl"
);
ftlContext
,
"reinforcement/agent/Strategy.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/strategy.py"
,
reinforcementStrategyContent
);
final
String
replayMemoryContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/agent/ReplayMemory.ftl"
);
ftlContext
,
"reinforcement/agent/ReplayMemory.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/replay_memory.py"
,
replayMemoryContent
);
final
String
environmentContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/environment/Environment.ftl"
);
ftlContext
,
"reinforcement/environment/Environment.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/environment.py"
,
environmentContent
);
final
String
utilContent
=
templateConfiguration
.
processTemplate
(
ftlContext
,
"reinforcement/util/Util.ftl"
);
ftlContext
,
"reinforcement/util/Util.ftl"
);
fileContentMap
.
put
(
REINFORCEMENT_LEARNING_FRAMEWORK_MODULE
+
"/util.py"
,
utilContent
);
final
String
initContent
=
""
;
...
...
@@ -322,3 +338,4 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
return
fileContentMap
;
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/preprocessing/PreprocessingComponentParameterAdapter.java
0 → 100644
View file @
d03d1284
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing
;
import
de.monticore.lang.monticar.cnntrain.annotations.PreprocessingComponentParameter
;
import
de.monticore.lang.monticar.cnntrain.annotations.RewardFunctionParameter
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType
;
import
de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable
;
import
java.util.List
;
import
java.util.Optional
;
import
java.util.stream.Collectors
;
/**
*
*/
public
class
PreprocessingComponentParameterAdapter
implements
PreprocessingComponentParameter
{
private
final
ComponentPortInformation
adaptee
;
private
String
outputParameterName
;
private
String
inputStateParameterName
;
private
String
inputTerminalParameterName
;
public
PreprocessingComponentParameterAdapter
(
final
ComponentPortInformation
componentPortInformation
)
{
this
.
adaptee
=
componentPortInformation
;
}
@Override
public
List
<
String
>
getInputNames
()
{
return
this
.
adaptee
.
getAllInputs
().
stream
()
.
map
(
PortVariable:
:
getVariableName
)
.
collect
(
Collectors
.
toList
());
}
@Override
public
List
<
String
>
getOutputNames
()
{
return
this
.
adaptee
.
getAllOutputs
().
stream
()
.
map
(
PortVariable:
:
getVariableName
)
.
collect
(
Collectors
.
toList
());
}
@Override
public
Optional
<
String
>
getTypeOfInputPort
(
String
portName
)
{
return
this
.
adaptee
.
getAllInputs
().
stream
()
.
filter
(
port
->
port
.
getVariableName
().
equals
(
portName
))
.
map
(
port
->
port
.
getEmadlType
().
toString
())
.
findFirst
();
}
@Override
public
Optional
<
String
>
getTypeOfOutputPort
(
String
portName
)
{
return
this
.
adaptee
.
getAllOutputs
().
stream
()
.
filter
(
port
->
port
.
getVariableName
().
equals
(
portName
))
.
map
(
port
->
port
.
getEmadlType
().
toString
())
.
findFirst
();
}
@Override
public
Optional
<
List
<
Integer
>>
getInputPortDimensionOfPort
(
String
portName
)
{
return
this
.
adaptee
.
getAllInputs
().
stream
()
.
filter
(
port
->
port
.
getVariableName
().
equals
(
portName
))
.
map
(
PortVariable:
:
getDimension
)
.
findFirst
();
}
@Override
public
Optional
<
List
<
Integer
>>
getOutputPortDimensionOfPort
(
String
portName
)
{
return
this
.
adaptee
.
getAllOutputs
().
stream
()
.
filter
(
port
->
port
.
getVariableName
().
equals
(
portName
))
.
map
(
PortVariable:
:
getDimension
)
.
findFirst
();
}
public
Optional
<
String
>
getOutputParameterName
()
{
if
(
this
.
outputParameterName
==
null
)
{
if
(
this
.
getOutputNames
().
size
()
==
1
)
{
this
.
outputParameterName
=
this
.
getOutputNames
().
get
(
0
);
}
else
{
return
Optional
.
empty
();
}
}
return
Optional
.
of
(
this
.
outputParameterName
);
}
private
boolean
isBooleanScalar
(
final
PortVariable
portVariable
)
{
return
portVariable
.
getEmadlType
().
equals
(
EmadlType
.
B
)
&&
portVariable
.
getDimension
().
size
()
==
1
&&
portVariable
.
getDimension
().
get
(
0
)
==
1
;
}
private
boolean
determineInputNames
()
{
if
(
this
.
getInputNames
().
size
()
!=
2
)
{
return
false
;
}
Optional
<
String
>
terminalInput
=
this
.
adaptee
.
getAllInputs
()
.
stream
()
.
filter
(
this
::
isBooleanScalar
)
.
map
(
PortVariable:
:
getVariableName
)
.
findFirst
();
if
(
terminalInput
.
isPresent
())
{
this
.
inputTerminalParameterName
=
terminalInput
.
get
();
}
else
{
return
false
;
}
Optional
<
String
>
stateInput
=
this
.
adaptee
.
getAllInputs
().
stream
()
.
filter
(
portVariable
->
!
portVariable
.
getVariableName
().
equals
(
this
.
inputTerminalParameterName
))
.
filter
(
portVariable
->
!
isBooleanScalar
(
portVariable
))
.
map
(
PortVariable:
:
getVariableName
)
.
findFirst
();
if
(
stateInput
.
isPresent
())
{
this
.
inputStateParameterName
=
stateInput
.
get
();
}
else
{
this
.
inputTerminalParameterName
=
null
;
return
false
;
}
return
true
;
}
public
Optional
<
String
>
getInputStateParameterName
()
{
if
(
this
.
inputStateParameterName
==
null
)
{
this
.
determineInputNames
();
}
return
Optional
.
ofNullable
(
this
.
inputStateParameterName
);
}
public
Optional
<
String
>
getInputTerminalParameter
()
{
if
(
this
.
inputTerminalParameterName
==
null
)
{
this
.
determineInputNames
();
}
return
Optional
.
ofNullable
(
this
.
inputTerminalParameterName
);
}
}
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/preprocessing/PreprocessingPortChecker.java
0 → 100644
View file @
d03d1284
/* (c) https://github.com/MontiCore/monticore */
package
de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing
;
import
de.se_rwth.commons.logging.Log
;
import
java.util.HashSet
;
import
java.util.ListIterator
;
import
java.util.Set
;
/**
*
*/
public
class
PreprocessingPortChecker
{
public
PreprocessingPortChecker
()
{
}
static
public
void
check
(
final
PreprocessingComponentParameterAdapter
preprocessingComponentParameter
)
{
assert
preprocessingComponentParameter
!=
null
;
checkEqualNumberofInAndOutPorts
(
preprocessingComponentParameter
);
checkCorrectPortNames
(
preprocessingComponentParameter
);
}
static
private
void
checkEqualNumberofInAndOutPorts
(
PreprocessingComponentParameterAdapter
preprocessingComponentParameter
)
{
failIfConditionFails
(
equalNumberOfInAndOutPorts
(
preprocessingComponentParameter
),
"The number of in- and output ports of the "
+
"preprocessing component is not equal"
);
}
static
private
boolean
equalNumberOfInAndOutPorts
(
PreprocessingComponentParameterAdapter
preprocessingComponentParameter
)
{
return
preprocessingComponentParameter
.
getInputNames
().
size
()
==
preprocessingComponentParameter
.
getOutputNames
().
size
();
}
static
private
void
checkCorrectPortNames
(
PreprocessingComponentParameterAdapter
preprocessingComponentParameter
)
{
failIfConditionFails
(
correctPortNames
(
preprocessingComponentParameter
),
"The output ports are not correctly named with \"_out\" appendix"
);
}
static
private
boolean
correctPortNames
(
PreprocessingComponentParameterAdapter
preprocessingComponentParameter
)
{
ListIterator
<
String
>
iterator
=
preprocessingComponentParameter
.
getInputNames
().
listIterator
();
Set
<
String
>
inputs
=
new
HashSet
<
String
>();
while
(
iterator
.
hasNext
())
{
inputs
.
add
(
iterator
.
next
()
+
"_out"
);
}
Set
<
String
>
outputs
=
new
HashSet
<
String
>(
preprocessingComponentParameter
.
getOutputNames
());
return
inputs
.
equals
(
outputs
);
}
static
private
void
failIfConditionFails
(
final
boolean
condition
,
final
String
message
)
{
if
(!
condition
)
{
fail
(
message
);
}
}
static
private
void
fail
(
final
String
message
)
{
Log
.
error
(
message
);
//System.exit(-1);
}
}
src/main/resources/templates/gluon/CNNCreator.ftl
View file @
d03d1284
...
...
@@ -2,6 +2,7 @@
import mxnet as mx
import logging
import os
import shutil
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
...
...
@@ -14,6 +15,11 @@ class ${tc.fileNameWithoutEnding}:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
<#if (tc.weightsPath)??>
self._weights_dir_ = "${tc.weightsPath}/"
<#else>
self._weights_dir_ = None
</#if>
def load(self, context):
earliestLastEpoch = None
...
...
@@ -50,6 +56,29 @@ class ${tc.fileNameWithoutEnding}:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
...
...
@@ -63,3 +92,29 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
input_dimensions = (${tc.join(dimensions[name], ",")},)
input_domains = (${tc.join(domains[name], ",")},)
inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
return inputs
def getOutputs(self):
outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
output_dimensions = (${tc.join(dimensions[name], ",")},)
output_domains = (${tc.join(domains[name], ",")},)
outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
return outputs
src/main/resources/templates/gluon/CNNDataLoader.ftl
View file @
d03d1284
...
...
@@ -5,7 +5,6 @@ import mxnet as mx
import logging
import sys
import numpy as np
import cv2
import importlib
from mxnet import nd
...
...
@@ -79,6 +78,7 @@ class ${tc.fileNameWithoutEnding}:
train_label = {}
data_mean = {}
data_std = {}
train_images = {}
shape_output = self.preprocess_data(instance, inp, 0, train_h5)
train_len = len(train_h5[self._input_names_[0]])
...
...
@@ -141,6 +141,7 @@ class ${tc.fileNameWithoutEnding}:
for output_name in self._output_names_:
test_label[output_name][i] = getattr(shape_output, output_name + "_out")
test_images = {}
if 'images' in test_h5:
test_images = test_h5['images']
...
...
@@ -152,7 +153,7 @@ class ${tc.fileNameWithoutEnding}:
def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
for input_name in self._input_names_:
data = data_h5[input_name][
0
]
data = data_h5[input_name][
index
]
attr = getattr(input_wrapper, input_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
...
...
@@ -160,7 +161,7 @@ class ${tc.fileNameWithoutEnding}:
data = type(attr)(data)
setattr(input_wrapper, input_name, data)