Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
C
CNNArch2Caffe2
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Iterations
Merge Requests
0
Merge Requests
0
Requirements
Requirements
List
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Test Cases
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Operations
Operations
Incidents
Environments
Packages & Registries
Packages & Registries
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issue
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Caffe2
Commits
7890b779
Commit
7890b779
authored
Sep 10, 2018
by
Carlos Alfredo Yeverino Rodriguez
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added class CNNTrain2Caffe2 impelemnting corresponding interface
parent
84d1feea
Pipeline
#72613
passed with stages
in 4 minutes and 19 seconds
Changes
17
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
185 additions
and
291 deletions
+185
-291
pom.xml
pom.xml
+3
-3
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
...lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
+8
-35
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
...ang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
+84
-0
src/main/resources/templates/caffe2/CNNTrainer.ftl
src/main/resources/templates/caffe2/CNNTrainer.ftl
+10
-10
src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/AbstractSymtabTest.java
.../monticar/cnnarch/caffe2generator/AbstractSymtabTest.java
+0
-9
src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/GenerationTest.java
...lang/monticar/cnnarch/caffe2generator/GenerationTest.java
+13
-112
src/test/resources/target_code/CMakeLists.txt
src/test/resources/target_code/CMakeLists.txt
+1
-1
src/test/resources/target_code/CNNTrainer_emptyConfig.py
src/test/resources/target_code/CNNTrainer_emptyConfig.py
+13
-0
src/test/resources/target_code/CNNTrainer_fullConfig.py
src/test/resources/target_code/CNNTrainer_fullConfig.py
+34
-0
src/test/resources/target_code/CNNTrainer_main.py
src/test/resources/target_code/CNNTrainer_main.py
+0
-39
src/test/resources/target_code/CNNTrainer_mainSimple.py
src/test/resources/target_code/CNNTrainer_mainSimple.py
+0
-27
src/test/resources/target_code/CNNTrainer_simpleConfig.py
src/test/resources/target_code/CNNTrainer_simpleConfig.py
+18
-0
src/test/resources/valid_tests/FullConfig2.cnnt
src/test/resources/valid_tests/FullConfig2.cnnt
+0
-21
src/test/resources/valid_tests/Network1.cnnt
src/test/resources/valid_tests/Network1.cnnt
+0
-13
src/test/resources/valid_tests/Network2.cnnt
src/test/resources/valid_tests/Network2.cnnt
+0
-13
src/test/resources/valid_tests/SimpleConfig.cnnt
src/test/resources/valid_tests/SimpleConfig.cnnt
+1
-1
src/test/resources/valid_tests/SimpleConfig2.cnnt
src/test/resources/valid_tests/SimpleConfig2.cnnt
+0
-7
No files found.
pom.xml
View file @
7890b779
...
...
@@ -8,15 +8,15 @@
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-caffe2-generator
</artifactId>
<version>
0.2.
4
-SNAPSHOT
</version>
<version>
0.2.
5
-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>
0.2.
5
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.2.
4
-SNAPSHOT
</CNNTrain.version>
<CNNArch.version>
0.2.
6
-SNAPSHOT
</CNNArch.version>
<CNNTrain.version>
0.2.
5
-SNAPSHOT
</CNNTrain.version>
<embedded-montiarc-math-generator>
0.0.25-SNAPSHOT
</embedded-montiarc-math-generator>
<!-- .. Libraries .................................................. -->
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNArch2Caffe2.java
View file @
7890b779
...
...
@@ -26,19 +26,19 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cmake.CMakeConfig
;
import
de.monticore.lang.monticar.generator.cmake.CMakeFindModule
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.symboltable.GlobalScope
;
import
de.monticore.symboltable.Scope
;
import
de.se_rwth.commons.logging.Log
;
import
java.io.File
;
import
java.io.FileWriter
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.util.*
;
import
java.util.HashMap
;
import
java.util.Map
;
import
java.util.Optional
;
public
class
CNNArch2Caffe2
implements
CNNArchGenerator
{
...
...
@@ -87,24 +87,6 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
}
}
@Override
public
Map
<
String
,
String
>
generateTrainer
(
List
<
ConfigurationSymbol
>
configurations
,
List
<
String
>
instanceNames
,
String
mainComponentName
)
{
int
numberOfNetworks
=
configurations
.
size
();
if
(
configurations
.
size
()
!=
instanceNames
.
size
()){
throw
new
IllegalStateException
(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. "
+
"This should have been checked previously."
);
}
List
<
ConfigurationData
>
configDataList
=
new
ArrayList
<>();
for
(
int
i
=
0
;
i
<
numberOfNetworks
;
i
++){
configDataList
.
add
(
new
ConfigurationData
(
configurations
.
get
(
i
),
instanceNames
.
get
(
i
)));
}
Map
<
String
,
Object
>
ftlContext
=
Collections
.
singletonMap
(
"configurations"
,
configDataList
);
return
Collections
.
singletonMap
(
"CNNTrainer_"
+
mainComponentName
+
".py"
,
TemplateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public
Map
<
String
,
String
>
generateStrings
(
ArchitectureSymbol
architecture
){
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
...
...
@@ -163,19 +145,10 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
}
private
void
generateFromFilecontentsMap
(
Map
<
String
,
String
>
fileContentMap
)
throws
IOException
{
GeneratorCPP
genCPP
=
new
GeneratorCPP
();
genCPP
.
setGenerationTargetPath
(
getGenerationTargetPath
());
for
(
String
fileName
:
fileContentMap
.
keySet
()){
File
f
=
new
File
(
getGenerationTargetPath
()
+
fileName
);
Log
.
info
(
f
.
getName
(),
"FileCreation:"
);
if
(!
f
.
exists
())
{
f
.
getParentFile
().
mkdirs
();
if
(!
f
.
createNewFile
())
{
Log
.
error
(
"File could not be created"
);
}
}
FileWriter
writer
=
new
FileWriter
(
f
);
writer
.
write
(
fileContentMap
.
get
(
fileName
));
writer
.
close
();
genCPP
.
generateFile
(
new
FileContent
(
fileContentMap
.
get
(
fileName
),
fileName
));
}
}
...
...
@@ -186,7 +159,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
CMakeConfig
cMakeConfig
=
new
CMakeConfig
(
rootModelName
);
cMakeConfig
.
addModuleDependency
(
new
CMakeFindModule
(
"Armadillo"
,
true
));
cMakeConfig
.
addCMakeCommand
End
(
"set(LIBS ${LIBS} mxnet)"
);
cMakeConfig
.
addCMakeCommand
(
"set(LIBS ${LIBS} mxnet)"
);
Map
<
String
,
String
>
fileContentMap
=
new
HashMap
<>();
for
(
FileContent
fileContent
:
cMakeConfig
.
generateCMakeFiles
()){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/caffe2generator/CNNTrain2Caffe2.java
0 → 100644
View file @
7890b779
package
de.monticore.lang.monticar.cnnarch.caffe2generator
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnntrain.CNNTrainGenerator
;
import
de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.lang.monticar.generator.FileContent
;
import
de.monticore.lang.monticar.generator.cpp.GeneratorCPP
;
import
de.monticore.symboltable.GlobalScope
;
import
de.se_rwth.commons.logging.Log
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.util.*
;
public
class
CNNTrain2Caffe2
implements
CNNTrainGenerator
{
private
String
generationTargetPath
;
private
String
instanceName
;
public
CNNTrain2Caffe2
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
public
String
getInstanceName
()
{
String
parsedInstanceName
=
this
.
instanceName
.
replace
(
'.'
,
'_'
).
replace
(
'['
,
'_'
).
replace
(
']'
,
'_'
);
parsedInstanceName
=
parsedInstanceName
.
substring
(
0
,
1
).
toLowerCase
()
+
parsedInstanceName
.
substring
(
1
);
return
parsedInstanceName
;
}
public
void
setInstanceName
(
String
instanceName
)
{
this
.
instanceName
=
instanceName
;
}
public
String
getGenerationTargetPath
()
{
if
(
generationTargetPath
.
charAt
(
generationTargetPath
.
length
()
-
1
)
!=
'/'
)
{
this
.
generationTargetPath
=
generationTargetPath
+
"/"
;
}
return
generationTargetPath
;
}
public
void
setGenerationTargetPath
(
String
generationTargetPath
)
{
this
.
generationTargetPath
=
generationTargetPath
;
}
public
ConfigurationSymbol
getConfigurationSymbol
(
Path
modelsDirPath
,
String
rootModelName
)
{
final
ModelPath
mp
=
new
ModelPath
(
modelsDirPath
);
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
Optional
<
CNNTrainCompilationUnitSymbol
>
compilationUnit
=
scope
.
resolve
(
rootModelName
,
CNNTrainCompilationUnitSymbol
.
KIND
);
if
(!
compilationUnit
.
isPresent
())
{
Log
.
error
(
"could not resolve training configuration "
+
rootModelName
);
System
.
exit
(
1
);
}
setInstanceName
(
compilationUnit
.
get
().
getFullName
());
CNNTrainCocos
.
checkAll
(
compilationUnit
.
get
());
return
compilationUnit
.
get
().
getConfiguration
();
}
public
void
generate
(
Path
modelsDirPath
,
String
rootModelName
)
{
ConfigurationSymbol
configuration
=
getConfigurationSymbol
(
modelsDirPath
,
rootModelName
);
Map
<
String
,
String
>
fileContents
=
generateStrings
(
configuration
);
GeneratorCPP
genCPP
=
new
GeneratorCPP
();
genCPP
.
setGenerationTargetPath
(
getGenerationTargetPath
());
try
{
for
(
String
fileName
:
fileContents
.
keySet
()){
genCPP
.
generateFile
(
new
FileContent
(
fileContents
.
get
(
fileName
),
fileName
));
}
}
catch
(
IOException
e
)
{
e
.
printStackTrace
();
}
}
public
Map
<
String
,
String
>
generateStrings
(
ConfigurationSymbol
configuration
)
{
ConfigurationData
configData
=
new
ConfigurationData
(
configuration
,
getInstanceName
());
List
<
ConfigurationData
>
configDataList
=
new
ArrayList
<>();
configDataList
.
add
(
configData
);
Map
<
String
,
Object
>
ftlContext
=
Collections
.
singletonMap
(
"configurations"
,
configDataList
);
String
templateContent
=
TemplateConfiguration
.
processTemplate
(
ftlContext
,
"CNNTrainer.ftl"
);
return
Collections
.
singletonMap
(
"CNNTrainer_"
+
getInstanceName
()
+
".py"
,
templateContent
);
}
}
src/main/resources/templates/caffe2/CNNTrainer.ftl
View file @
7890b779
...
...
@@ -7,37 +7,37 @@ import CNNCreator_${config.instanceName}
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log","w", encoding=None, delay="true")
handler = logging.FileHandler("train.log",
"w", encoding=None, delay="true")
logger.addHandler(handler)
<#list configurations as config>
${config.instanceName} = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
${config.instanceName}.train(
<#if (config.batchSize)??>
batch_size
=
${config.batchSize},
batch_size
=
${config.batchSize},
</#if>
<#if (config.numEpoch)??>
num_epoch
=
${config.numEpoch},
num_epoch
=
${config.numEpoch},
</#if>
<#if (config.loadCheckpoint)??>
load_checkpoint
=
${config.loadCheckpoint?string("True","False")},
load_checkpoint
=
${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
context
=
'${config.context}',
context
=
'${config.context}',
</#if>
<#if (config.normalize)??>
normalize
=
${config.normalize?string("True","False")},
normalize
=
${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
eval_metric
=
'${config.evalMetric}',
eval_metric
=
'${config.evalMetric}',
</#if>
<#if (config.configuration.optimizer)??>
optimizer
=
'${config.optimizerName}',
optimizer_params
=
{
optimizer
=
'${config.optimizerName}',
optimizer_params
=
{
<#list config.optimizerParams?keys as param>
'${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
}
}
</#if>
)
</#list>
\ No newline at end of file
src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/AbstractSymtabTest.java
View file @
7890b779
...
...
@@ -60,15 +60,6 @@ public class AbstractSymtabTest {
return
scope
;
}
/* protected static ASTCNNArchCompilationUnit getAstNode(String modelPath, String model) {
Scope symTab = createSymTab(MODEL_PATH + modelPath);
CNNArchCompilationUnitSymbol comp = symTab.<CNNArchCompilationUnitSymbol> resolve(
model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
assertNotNull("Could not resolve model " + model, comp);
return (ASTCNNArchCompilationUnit) comp.getAstNode().get();
}*/
protected
static
CNNArchCompilationUnitSymbol
getCompilationUnitSymbol
(
String
modelPath
,
String
model
)
{
Scope
symTab
=
createSymTab
(
MODEL_PATH
+
modelPath
);
CNNArchCompilationUnitSymbol
comp
=
symTab
.<
CNNArchCompilationUnitSymbol
>
resolve
(
...
...
src/test/java/de/monticore/lang/monticar/cnnarch/caffe2generator/GenerationTest.java
View file @
7890b779
...
...
@@ -20,19 +20,13 @@
*/
package
de.monticore.lang.monticar.cnnarch.caffe2generator
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol
;
import
de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage
;
import
de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol
;
import
de.monticore.symboltable.GlobalScope
;
import
de.se_rwth.commons.logging.Log
;
import
freemarker.template.TemplateException
;
import
org.junit.Before
;
import
org.junit.Test
;
import
java.io.FileWriter
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.util.*
;
...
...
@@ -121,143 +115,50 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue
(
Log
.
getFindings
().
size
()
==
3
);
}
@Test
public
void
testCNNTrainerGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceNames
=
Arrays
.
asList
(
"main_net1"
,
"main_net2"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"Network1"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"Network2"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2Caffe2
generator
=
new
CNNArch2Caffe2
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceNames
,
"main"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_main.py"
));
}
@Test
public
void
testFullCfgGeneration
()
throws
IOException
,
TemplateException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceName
=
Arrays
.
asList
(
"main_net1"
,
"main_net2"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"FullConfig"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"FullConfig2"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2Caffe2
generator
=
new
CNNArch2Caffe2
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceName
,
"mainFull"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
String
sourcePath
=
"src/test/resources/valid_tests"
;
CNNTrain2Caffe2
trainGenerator
=
new
CNNTrain2Caffe2
();
trainGenerator
.
generate
(
Paths
.
get
(
sourcePath
),
"FullConfig"
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_
mainFull
.py"
));
"CNNTrainer_
fullConfig
.py"
));
}
@Test
public
void
testSimpleCfgGeneration
()
throws
IOException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceName
=
Arrays
.
asList
(
"main_net1"
,
"main_net2"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"SimpleConfig1"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
Path
modelPath
=
Paths
.
get
(
"src/test/resources/valid_tests"
);
CNNTrain2Caffe2
trainGenerator
=
new
CNNTrain2Caffe2
();
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"SimpleConfig2"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2Caffe2
generator
=
new
CNNArch2Caffe2
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceName
,
"mainSimple"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
trainGenerator
.
generate
(
modelPath
,
"SimpleConfig"
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_
mainSimple
.py"
));
"CNNTrainer_
simpleConfig
.py"
));
}
@Test
public
void
testEmptyCfgGeneration
()
throws
IOException
{
Log
.
getFindings
().
clear
();
List
<
ConfigurationSymbol
>
configurations
=
new
ArrayList
<>();
List
<
String
>
instanceName
=
Arrays
.
asList
(
"main_net1"
);
final
ModelPath
mp
=
new
ModelPath
(
Paths
.
get
(
"src/test/resources/valid_tests"
));
GlobalScope
scope
=
new
GlobalScope
(
mp
,
new
CNNTrainLanguage
());
CNNTrainCompilationUnitSymbol
compilationUnit
=
scope
.<
CNNTrainCompilationUnitSymbol
>
resolve
(
"EmptyConfig"
,
CNNTrainCompilationUnitSymbol
.
KIND
).
get
();
CNNTrainCocos
.
checkAll
(
compilationUnit
);
configurations
.
add
(
compilationUnit
.
getConfiguration
());
CNNArch2Caffe2
generator
=
new
CNNArch2Caffe2
();
Map
<
String
,
String
>
trainerMap
=
generator
.
generateTrainer
(
configurations
,
instanceName
,
"mainEmpty"
);
for
(
String
fileName
:
trainerMap
.
keySet
()){
FileWriter
writer
=
new
FileWriter
(
generator
.
getGenerationTargetPath
()
+
fileName
);
writer
.
write
(
trainerMap
.
get
(
fileName
));
writer
.
close
();
}
Path
modelPath
=
Paths
.
get
(
"src/test/resources/valid_tests"
);
CNNTrain2Caffe2
trainGenerator
=
new
CNNTrain2Caffe2
();
trainGenerator
.
generate
(
modelPath
,
"EmptyConfig"
);
assertTrue
(
Log
.
getFindings
().
isEmpty
());
checkFilesAreEqual
(
Paths
.
get
(
"./target/generated-sources-cnnarch"
),
Paths
.
get
(
"./src/test/resources/target_code"
),
Arrays
.
asList
(
"CNNTrainer_
mainEmpty
.py"
));
"CNNTrainer_
emptyConfig
.py"
));
}
...
...
src/test/resources/target_code/CMakeLists.txt
View file @
7890b779
...
...
@@ -12,6 +12,7 @@ set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set
(
LIBS
${
LIBS
}
${
Armadillo_LIBRARIES
}
)
# additional commands
set
(
LIBS
${
LIBS
}
mxnet
)
# create static library
include_directories
(
${
INCLUDE_DIRS
}
)
...
...
@@ -24,4 +25,3 @@ set_target_properties(alexnet PROPERTIES LINKER_LANGUAGE CXX)
export
(
TARGETS alexnet FILE alexnet.cmake
)
# additional commands end
set
(
LIBS
${
LIBS
}
mxnet
)
src/test/resources/target_code/CNNTrainer_
mainEmpty
.py
→
src/test/resources/target_code/CNNTrainer_
emptyConfig
.py
View file @
7890b779
import
logging
import
mxnet
as
mx
import
CNNCreator_
main_net1
import
CNNCreator_
emptyConfig
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logger
=
logging
.
getLogger
()
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_main_net1
.
CNNCreator_main_net1
()
main_net1
.
train
(
emptyConfig
=
CNNCreator_emptyConfig
.
CNNCreator_emptyConfig
()
emptyConfig
.
train
(
)
src/test/resources/target_code/CNNTrainer_
mainFull
.py
→
src/test/resources/target_code/CNNTrainer_
fullConfig
.py
View file @
7890b779
import
logging
import
mxnet
as
mx
import
CNNCreator_main_net1
import
CNNCreator_main_net2
import
CNNCreator_fullConfig
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logger
=
logging
.
getLogger
()
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_main_net1
.
CNNCreator_main_net1
()
main_net1
.
train
(
batch_size
=
100
,
num_epoch
=
5
,
load_checkpoint
=
True
,
context
=
'gpu'
,
normalize
=
True
,
eval_metric
=
'mse'
,
optimizer
=
'rmsprop'
,
optimizer_params
=
{
fullConfig
=
CNNCreator_fullConfig
.
CNNCreator_fullConfig
()
fullConfig
.
train
(
batch_size
=
100
,
num_epoch
=
5
,
load_checkpoint
=
True
,
context
=
'gpu'
,
normalize
=
True
,
eval_metric
=
'mse'
,
optimizer
=
'rmsprop'
,
optimizer_params
=
{
'weight_decay'
:
0.01
,
'centered'
:
True
,
'gamma2'
:
0.9
,
...
...
@@ -31,27 +30,5 @@ if __name__ == "__main__":
'learning_rate_minimum'
:
1.0E-5
,
'learning_rate_policy'
:
'step'
,
'learning_rate'
:
0.001
,
'step_size'
:
1000
}
)
main_net2
=
CNNCreator_main_net2
.
CNNCreator_main_net2
()
main_net2
.
train
(
batch_size
=
100
,
num_epoch
=
10
,
load_checkpoint
=
False
,
context
=
'gpu'
,
normalize
=
False
,
eval_metric
=
'topKAccuracy'
,
optimizer
=
'adam'
,
optimizer_params
=
{
'epsilon'
:
1.0E-6
,
'weight_decay'
:
0.01
,
'rescale_grad'
:
1.1
,
'beta1'
:
0.9
,
'clip_gradient'
:
10.0
,
'beta2'
:
0.9
,
'learning_rate_minimum'
:
0.001
,
'learning_rate_policy'
:
'exp'
,
'learning_rate'
:
0.001
,
'learning_rate_decay'
:
0.9
,
'step_size'
:
1000
}
'step_size'
:
1000
}
)
src/test/resources/target_code/CNNTrainer_main.py
deleted
100644 → 0
View file @
84d1feea
import
logging
import
mxnet
as
mx
import
CNNCreator_main_net1
import
CNNCreator_main_net2
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logger
=
logging
.
getLogger
()
handler
=
logging
.
FileHandler
(
"train.log"
,
"w"
,
encoding
=
None
,
delay
=
"true"
)
logger
.
addHandler
(
handler
)
main_net1
=
CNNCreator_main_net1
.
CNNCreator_main_net1
()
main_net1
.
train
(
batch_size
=
64
,
num_epoch
=
10
,
load_checkpoint
=
False
,
context
=
'gpu'
,
normalize
=
True
,
optimizer
=
'adam'
,
optimizer_params
=
{
'weight_decay'
:
1.0E-4
,
'learning_rate'
:
0.01
,
'learning_rate_decay'
:
0.8
,
'step_size'
:
1000
}
)
main_net2
=
CNNCreator_main_net2
.
CNNCreator_main_net2
()
main_net2
.
train
(
batch_size
=
32
,
num_epoch
=
10
,
load_checkpoint
=
False
,
context
=
'gpu'
,
normalize
=
True
,
optimizer
=
'adam'
,
optimizer_params
=
{
'weight_decay'
:
1.0E-4
,
'learning_rate'
:
0.01
,
'learning_rate_decay'
:
0.8
,
'step_size'
:
1000
}
)
src/test/resources/target_code/CNNTrainer_mainSimple.py
deleted
100644 → 0
View file @
84d1feea