Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2Gluon
Commits
f9a53bef
Commit
f9a53bef
authored
Jan 29, 2019
by
Evgeny Kusmenko
Browse files
Merge branch 'move-to-gluon' into 'master'
Move to gluon Closes
#1
See merge request
!1
parents
4afa996b
842fbaff
Pipeline
#100989
passed with stages
in 4 minutes and 24 seconds
Changes
62
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
f9a53bef
...
...
@@ -7,8 +7,8 @@
<!-- == PROJECT COORDINATES ============================================= -->
<groupId>
de.monticore.lang.monticar
</groupId>
<artifactId>
cnnarch-
mxnet
-generator
</artifactId>
<version>
0.
2.8
</version>
<artifactId>
cnnarch-
gluon
-generator
</artifactId>
<version>
0.
1.0-SNAPSHOT
</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
...
...
@@ -172,7 +172,7 @@
<configuration>
<archive>
<manifest>
<mainClass>
de.monticore.lang.monticar.cnnarch.
mxnet
generator.CNNArch2
MxNet
Cli
</mainClass>
<mainClass>
de.monticore.lang.monticar.cnnarch.
gluon
generator.CNNArch2
Gluon
Cli
</mainClass>
</manifest>
</archive>
<descriptorRefs>
...
...
@@ -229,7 +229,8 @@
<maxmem>
256m
</maxmem>
<!-- aggregated reports for multi-module projects -->
<aggregate>
true
</aggregate>
</configuration>
<check/>
</configuration>
</plugin>
</plugins>
</build>
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/ArchitectureElementData.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/ArchitectureElementData.java
View file @
f9a53bef
...
...
@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol
;
import
de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/CNNArch2
MxNet
.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/CNNArch2
Gluon
.java
View file @
f9a53bef
...
...
@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnnarch.CNNArchGenerator
;
...
...
@@ -40,11 +40,11 @@ import java.util.HashMap;
import
java.util.Map
;
import
java.util.Optional
;
public
class
CNNArch2
MxNet
implements
CNNArchGenerator
{
public
class
CNNArch2
Gluon
implements
CNNArchGenerator
{
private
String
generationTargetPath
;
public
CNNArch2
MxNet
()
{
public
CNNArch2
Gluon
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
...
...
@@ -96,6 +96,9 @@ public class CNNArch2MxNet implements CNNArchGenerator {
temp
=
archTc
.
process
(
"CNNPredictor"
,
Target
.
CPP
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
temp
=
archTc
.
process
(
"CNNNet"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
temp
=
archTc
.
process
(
"CNNCreator"
,
Target
.
PYTHON
);
fileContentMap
.
put
(
temp
.
getKey
(),
temp
.
getValue
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/CNNArch2
MxNet
Cli.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/CNNArch2
Gluon
Cli.java
View file @
f9a53bef
...
...
@@ -18,14 +18,14 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
org.apache.commons.cli.*
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
public
class
CNNArch2
MxNet
Cli
{
public
class
CNNArch2
Gluon
Cli
{
public
static
final
Option
OPTION_MODELS_PATH
=
Option
.
builder
(
"m"
)
.
longOpt
(
"models-dir"
)
...
...
@@ -48,7 +48,7 @@ public class CNNArch2MxNetCli {
.
required
(
false
)
.
build
();
private
CNNArch2
MxNet
Cli
()
{
private
CNNArch2
Gluon
Cli
()
{
}
public
static
void
main
(
String
[]
args
)
{
...
...
@@ -84,7 +84,7 @@ public class CNNArch2MxNetCli {
Path
modelsDirPath
=
Paths
.
get
(
cliArgs
.
getOptionValue
(
OPTION_MODELS_PATH
.
getOpt
()));
String
rootModelName
=
cliArgs
.
getOptionValue
(
OPTION_ROOT_MODEL
.
getOpt
());
String
outputPath
=
cliArgs
.
getOptionValue
(
OPTION_OUTPUT_PATH
.
getOpt
());
CNNArch2
MxNet
generator
=
new
CNNArch2
MxNet
();
CNNArch2
Gluon
generator
=
new
CNNArch2
Gluon
();
if
(
outputPath
!=
null
){
generator
.
setGenerationTargetPath
(
outputPath
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/CNNArchTemplateController.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/CNNArchTemplateController.java
View file @
f9a53bef
...
...
@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.predefined.Sigmoid
;
...
...
@@ -34,6 +34,7 @@ public class CNNArchTemplateController {
public
static
final
String
TEMPLATE_ELEMENTS_DIR_PATH
=
"elements/"
;
public
static
final
String
TEMPLATE_CONTROLLER_KEY
=
"tc"
;
public
static
final
String
ELEMENT_DATA_KEY
=
"element"
;
public
static
final
String
NET_DEFINITION_MODE_KEY
=
"definition_mode"
;
private
LayerNameCreator
nameManager
;
private
ArchitectureSymbol
architecture
;
...
...
@@ -123,34 +124,43 @@ public class CNNArchTemplateController {
return
list
;
}
public
void
include
(
String
relativePath
,
String
templateWithoutFileEnding
,
Writer
writer
){
public
void
include
(
String
relativePath
,
String
templateWithoutFileEnding
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
String
templatePath
=
relativePath
+
templateWithoutFileEnding
+
FTL_FILE_ENDING
;
Map
<
String
,
Object
>
ftlContext
=
new
HashMap
<>();
ftlContext
.
put
(
TEMPLATE_CONTROLLER_KEY
,
this
);
ftlContext
.
put
(
ELEMENT_DATA_KEY
,
getCurrentElement
());
ftlContext
.
put
(
NET_DEFINITION_MODE_KEY
,
netDefinitionMode
);
TemplateConfiguration
.
processTemplate
(
ftlContext
,
templatePath
,
writer
);
}
public
void
include
(
IOSymbol
ioElement
,
Writer
writer
){
public
void
include
(
String
relativePath
,
String
templateWithoutFileEnding
,
Writer
writer
)
{
String
templatePath
=
relativePath
+
templateWithoutFileEnding
+
FTL_FILE_ENDING
;
Map
<
String
,
Object
>
ftlContext
=
new
HashMap
<>();
ftlContext
.
put
(
TEMPLATE_CONTROLLER_KEY
,
this
);
ftlContext
.
put
(
ELEMENT_DATA_KEY
,
getCurrentElement
());
TemplateConfiguration
.
processTemplate
(
ftlContext
,
templatePath
,
writer
);
}
public
void
include
(
IOSymbol
ioElement
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
ioElement
);
if
(
ioElement
.
isAtomic
()){
if
(
ioElement
.
isInput
()){
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Input"
,
writer
);
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Input"
,
writer
,
netDefinitionMode
);
}
else
{
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Output"
,
writer
);
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
"Output"
,
writer
,
netDefinitionMode
);
}
}
else
{
include
(
ioElement
.
getResolvedThis
().
get
(),
writer
);
include
(
ioElement
.
getResolvedThis
().
get
(),
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
LayerSymbol
layer
,
Writer
writer
){
public
void
include
(
LayerSymbol
layer
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
layer
);
...
...
@@ -158,44 +168,48 @@ public class CNNArchTemplateController {
ArchitectureElementSymbol
nextElement
=
layer
.
getOutputElement
().
get
();
if
(!
isSoftmaxOutput
(
nextElement
)
&&
!
isLogisticRegressionOutput
(
nextElement
)){
String
templateName
=
layer
.
getDeclaration
().
getName
();
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
);
include
(
TEMPLATE_ELEMENTS_DIR_PATH
,
templateName
,
writer
,
netDefinitionMode
);
}
}
else
{
include
(
layer
.
getResolvedThis
().
get
(),
writer
);
include
(
layer
.
getResolvedThis
().
get
(),
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
CompositeElementSymbol
compositeElement
,
Writer
writer
){
public
void
include
(
CompositeElementSymbol
compositeElement
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
ArchitectureElementData
previousElement
=
getCurrentElement
();
setCurrentElement
(
compositeElement
);
for
(
ArchitectureElementSymbol
element
:
compositeElement
.
getElements
()){
include
(
element
,
writer
);
include
(
element
,
writer
,
netDefinitionMode
);
}
setCurrentElement
(
previousElement
);
}
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
Writer
writer
){
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
Writer
writer
,
NetDefinitionMode
netDefinitionMode
){
if
(
architectureElement
instanceof
CompositeElementSymbol
){
include
((
CompositeElementSymbol
)
architectureElement
,
writer
);
include
((
CompositeElementSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
else
if
(
architectureElement
instanceof
LayerSymbol
){
include
((
LayerSymbol
)
architectureElement
,
writer
);
include
((
LayerSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
else
{
include
((
IOSymbol
)
architectureElement
,
writer
);
include
((
IOSymbol
)
architectureElement
,
writer
,
netDefinitionMode
);
}
}
public
void
include
(
ArchitectureElementSymbol
architectureElement
){
public
void
include
(
ArchitectureElementSymbol
architectureElementSymbol
,
String
netDefinitionMode
)
{
include
(
architectureElementSymbol
,
NetDefinitionMode
.
fromString
(
netDefinitionMode
));
}
public
void
include
(
ArchitectureElementSymbol
architectureElement
,
NetDefinitionMode
netDefinitionMode
){
if
(
writer
==
null
){
throw
new
IllegalStateException
(
"missing writer"
);
}
include
(
architectureElement
,
writer
);
include
(
architectureElement
,
writer
,
netDefinitionMode
);
}
public
Map
.
Entry
<
String
,
String
>
process
(
String
templateNameWithoutEnding
,
Target
targetLanguage
){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/CNNTrain2
MxNet
.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/CNNTrain2
Gluon
.java
View file @
f9a53bef
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.io.paths.ModelPath
;
import
de.monticore.lang.monticar.cnntrain.CNNTrainGenerator
;
...
...
@@ -17,7 +17,7 @@ import java.io.IOException;
import
java.nio.file.Path
;
import
java.util.*
;
public
class
CNNTrain2
MxNet
implements
CNNTrainGenerator
{
public
class
CNNTrain2
Gluon
implements
CNNTrainGenerator
{
private
String
generationTargetPath
;
private
String
instanceName
;
...
...
@@ -58,7 +58,7 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
}
}
public
CNNTrain2
MxNet
()
{
public
CNNTrain2
Gluon
()
{
setGenerationTargetPath
(
"./target/generated-sources-cnnarch/"
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/ConfigurationData.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/ConfigurationData.java
View file @
f9a53bef
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.lang.monticar.cnntrain._symboltable.*
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/LayerNameCreator.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/LayerNameCreator.java
View file @
f9a53bef
...
...
@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.lang.monticar.cnnarch._symboltable.*
;
import
de.monticore.lang.monticar.cnnarch.predefined.Convolution
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/NetDefinitionMode.java
0 → 100644
View file @
f9a53bef
package
de.monticore.lang.monticar.cnnarch.gluongenerator
;
/**
 * Selects which part of the generated network a template should emit.
 *
 * <p>The CNNNet template includes the architecture body twice: once with
 * {@link #ARCHITECTURE_DEFINITION} inside the network's constructor and once with
 * {@link #FORWARD_FUNCTION} inside its {@code hybrid_forward} method. Element templates
 * compare the mode's {@code toString()} against these constant names.
 */
public enum NetDefinitionMode {
    /** Mode passed while the constructor part of the network is generated. */
    ARCHITECTURE_DEFINITION,
    /** Mode passed while the {@code hybrid_forward} part of the network is generated. */
    FORWARD_FUNCTION;

    /**
     * Converts the textual mode name used in templates into the matching constant.
     *
     * @param netDefinitionMode exact (case-sensitive) name of a constant of this enum
     * @return the constant whose name equals {@code netDefinitionMode}
     * @throws IllegalArgumentException if {@code netDefinitionMode} is {@code null} or
     *         does not name a constant of this enum
     */
    public static NetDefinitionMode fromString(final String netDefinitionMode) {
        // A switch over a null String would fail with a bare NullPointerException;
        // report it the same way as any other unknown mode instead.
        if (netDefinitionMode == null) {
            throw new IllegalArgumentException("Unknown Net Definition Mode: null");
        }
        switch (netDefinitionMode) {
            case "ARCHITECTURE_DEFINITION":
                return ARCHITECTURE_DEFINITION;
            case "FORWARD_FUNCTION":
                return FORWARD_FUNCTION;
            default:
                throw new IllegalArgumentException(
                        "Unknown Net Definition Mode: " + netDefinitionMode);
        }
    }
}
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/Target.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/Target.java
View file @
f9a53bef
...
...
@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
//can be removed
public
enum
Target
{
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/TemplateConfiguration.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/TemplateConfiguration.java
View file @
f9a53bef
...
...
@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.se_rwth.commons.logging.Log
;
import
freemarker.template.Configuration
;
...
...
@@ -38,7 +38,7 @@ public class TemplateConfiguration {
private
TemplateConfiguration
()
{
configuration
=
new
Configuration
(
Configuration
.
VERSION_2_3_23
);
configuration
.
setClassForTemplateLoading
(
TemplateConfiguration
.
class
,
"/templates/
mxnet
/"
);
configuration
.
setClassForTemplateLoading
(
TemplateConfiguration
.
class
,
"/templates/
gluon
/"
);
configuration
.
setDefaultEncoding
(
"UTF-8"
);
configuration
.
setTemplateExceptionHandler
(
TemplateExceptionHandler
.
RETHROW_HANDLER
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/
mxnet
generator/TrainParamSupportChecker.java
→
src/main/java/de/monticore/lang/monticar/cnnarch/
gluon
generator/TrainParamSupportChecker.java
View file @
f9a53bef
package
de.monticore.lang.monticar.cnnarch.
mxnet
generator
;
package
de.monticore.lang.monticar.cnnarch.
gluon
generator
;
import
de.monticore.lang.monticar.cnntrain._ast.*
;
import
de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor
;
...
...
src/main/resources/templates/
mxnet
/CNNBufferFile.ftl
→
src/main/resources/templates/
gluon
/CNNBufferFile.ftl
View file @
f9a53bef
File moved
src/main/resources/templates/
mxnet
/CNNCreator.ftl
→
src/main/resources/templates/
gluon
/CNNCreator.ftl
View file @
f9a53bef
...
...
@@ -6,6 +6,9 @@ import shutil
import
h5py
import
sys
import
numpy as np
import
time
from
mxnet import gluon, autograd, nd
from
CNNNet_$
{
tc
.fullArchitectureName
}
import Net
@
mx
.init.register
class
MyConstant(mx.init.Initializer):
...
...
@@ -17,7 +20,6 @@ class MyConstant(mx.init.Initializer):
class
$
{
tc
.fileNameWithoutEnding
}
:
module = None
_data_dir_ = "data/$
{
tc
.fullArchitectureName
}
/"
_model_dir_ = "model/$
{
tc
.fullArchitectureName
}
/"
_model_prefix_ = "$
{
tc
.architectureName
}
"
...
...
@@ -25,6 +27,9 @@ class ${tc.fileNameWithoutEnding}:
_input_shapes_ = [<#list tc.architecture.inputs as input>($
{
tc
.join
(
input
.definition.type.dimensions
,
","
)}
)</#list>]
_output_names_ = [$
{
tc
.join
(
tc
.architectureOutputs
,
","
,
"'"
,
"_label'"
)}
]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
def load(self, context):
lastEpoch = 0
...
...
@@ -51,11 +56,7 @@ class ${tc.fileNameWithoutEnding}:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.module.load(prefix=self._model_dir_ + self._model_prefix_,
epoch=lastEpoch,
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
self.net.load_parameters(param_file)
return lastEpoch
...
...
@@ -138,11 +139,11 @@ class ${tc.fileNameWithoutEnding}:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.
module
== None:
if self.
net
== None:
if normalize:
self.construct(mx_context, data_mean
, data_std
)
self.construct(
context=
mx_context, data_mean
=nd.array(data_mean), data_std=nd.array(data_std)
)
else:
self.construct(mx_context)
self.construct(
context=
mx_context)
begin_epoch = 0
if load_checkpoint:
...
...
@@ -157,23 +158,79 @@ class ${tc.fileNameWithoutEnding}:
if not os.path.isdir(self._model_dir_):
raise
self.module.fit(
train_data=train_iter,
eval_metric=eval_metric,
eval_data=test_iter,
optimizer=optimizer,
optimizer_params=optimizer_params,
batch_end_callback=mx.callback.Speedometer(batch_size),
epoch_end_callback=mx.callback.do_checkpoint(prefix=self._model_dir_ + self._model_prefix_, period=checkpoint_period),
begin_epoch=begin_epoch,
num_epoch=num_epoch + begin_epoch)
self.module.save_checkpoint(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.module.save_checkpoint(self._model_dir_ + self._model_prefix_ + '_newest', 0)
trainer = mx.gluon.Trainer(self.net.collect_params(), optimizer, optimizer_params)
if self.net.last_layer == 'softmax':
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
elif self.net.last_layer == 'sigmoid':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif self.net.last_layer == 'linear':
loss_function = mx.gluon.loss.L2Loss()
else: # TODO: Change default?
loss_function = mx.gluon.loss.L2Loss()
logging.warning("Invalid last_layer, defaulting to L2 loss")
speed_period = 50
tic = None
for epoch in range(begin_epoch, begin_epoch + num_epoch):
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
data = batch.data[0].as_in_context(mx_context)
label = batch.label[0].as_in_context(mx_context)
with autograd.record():
output = self.net(data)
loss = loss_function(output, label)
loss.backward()
trainer.step(batch_size)
if tic is None:
tic = time.time()
else:
if batch_i % speed_period == 0:
try:
speed = speed_period
*
batch_size / (time.time() - tic)
except ZeroDivisionError:
speed = float("inf")
logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))
tic = time.time()
tic = None
train_iter.reset()
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(train_iter):
data = batch.data[0].as_in_context(mx_context)
label = batch.label[0].as_in_context(mx_context)
output = self.net(data)
predictions = mx.nd.argmax(output, axis=1)
metric.update(preds=predictions, labels=label)
train_metric_score = metric.get()[1]
test_iter.reset()
metric = mx.metric.create(eval_metric)
for batch_i, batch in enumerate(test_iter):
data = batch.data[0].as_in_context(mx_context)
label = batch.label[0].as_in_context(mx_context)
output = self.net(data)
predictions = mx.nd.argmax(output, axis=1)
metric.update(preds=predictions, labels=label)
test_metric_score = metric.get()[1]
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch)
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0)
def construct(self, context, data_mean=None, data_std=None):
${
tc
.include
(
tc
.architecture.body
)}
self.module = mx.mod.Module(symbol=mx.symbol.Group([$
{
tc
.join
(
tc
.architectureOutputs
,
","
)}
]),
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
self.net = Net(data_mean=data_mean, data_std=data_std)
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
src/main/resources/templates/gluon/CNNNet.ftl
0 → 100644
View file @
f9a53bef
import
mxnet as mx
import
numpy as np
from
mxnet import gluon
class Softmax(gluon.HybridBlock):
    """Hybrid block that applies ``F.softmax`` to its input."""

    def __init__(self, **kwargs):
        super(Softmax, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        """Return ``F.softmax(x)`` using the backend module ``F`` passed in by the caller."""
        activated = F.softmax(x)
        return activated
class Split(gluon.HybridBlock):
    """Hybrid block that forwards its input to ``F.split`` with a stored axis and output count."""

    def __init__(self, num_outputs, axis=1, **kwargs):
        super(Split, self).__init__(**kwargs)
        with self.name_scope():
            self.num_outputs = num_outputs
            self.axis = axis

    def hybrid_forward(self, F, x):
        """Split ``x`` along ``self.axis`` into ``self.num_outputs`` parts."""
        pieces = F.split(data=x, axis=self.axis, num_outputs=self.num_outputs)
        return pieces
class Concatenate(gluon.HybridBlock):
    """Hybrid block that joins all of its inputs with ``F.concat`` along a stored dimension."""

    def __init__(self, dim=1, **kwargs):
        super(Concatenate, self).__init__(**kwargs)
        with self.name_scope():
            self.dim = dim

    def hybrid_forward(self, F, *x):
        """Concatenate every positional input along ``self.dim``."""
        concat_dim = self.dim
        joined = F.concat(*x, dim=concat_dim)
        return joined
class ZScoreNormalization(gluon.HybridBlock):
    """Hybrid block that subtracts a stored mean and divides by a stored standard deviation.

    ``data_mean`` and ``data_std`` are registered as non-differentiable parameters
    initialised to the given constant values. Both arguments must provide ``.shape``
    and ``.asnumpy()``.
    """

    def __init__(self, data_mean, data_std, **kwargs):
        super(ZScoreNormalization, self).__init__(**kwargs)
        with self.name_scope():
            self.data_mean = self.params.get(
                'data_mean', shape=data_mean.shape,
                init=mx.init.Constant(data_mean.asnumpy().tolist()), differentiable=False)
            # Take the shape from data_std itself; the previous code reused
            # data_mean.shape here, which is only correct when both shapes agree.
            self.data_std = self.params.get(
                'data_std', shape=data_std.shape,
                init=mx.init.Constant(data_std.asnumpy().tolist()), differentiable=False)

    def hybrid_forward(self, F, x, data_mean, data_std):
        """Return ``(x - data_mean) / data_std`` using broadcasting operators of ``F``."""
        x = F.broadcast_sub(x, data_mean)
        x = F.broadcast_div(x, data_std)
        return x
class Padding(gluon.HybridBlock):
    """Hybrid block that pads its input with zeros via ``F.pad`` using a stored pad width."""

    def __init__(self, padding, **kwargs):
        super(Padding, self).__init__(**kwargs)
        with self.name_scope():
            self.pad_width = padding

    def hybrid_forward(self, F, x):
        """Return ``x`` padded in constant mode with value 0 and ``self.pad_width``."""
        return F.pad(data=x,
                     mode='constant',
                     pad_width=self.pad_width,
                     constant_value=0)
class NoNormalization(gluon.HybridBlock):
    """Hybrid block that passes its input through unchanged."""

    def __init__(self, **kwargs):
        super(NoNormalization, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        """Return ``x`` as-is; ``F`` is accepted but not used."""
        return x
# Generated network class. Both method bodies are filled in by the template
# controller: the architecture body is included once in "ARCHITECTURE_DEFINITION"
# mode inside __init__ (under name_scope) and once in "FORWARD_FUNCTION" mode
# inside hybrid_forward.
# data_mean / data_std default to None and are not referenced in this template
# text itself; presumably the included element templates use them - confirm there.
class
Net(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None,
**
kwargs):
super(Net, self).__init__(
**
kwargs)
with self.name_scope():
${
tc
.include
(
tc
.architecture.body
,
"ARCHITECTURE_DEFINITION")
}
def hybrid_forward(self, F, x):
${
tc
.include
(
tc
.architecture.body
,
"FORWARD_FUNCTION")
}
\ No newline at end of file
src/main/resources/templates/
mxnet
/CNNPredictor.ftl
→
src/main/resources/templates/
gluon
/CNNPredictor.ftl
View file @
f9a53bef
File moved
src/main/resources/templates/
mxnet
/CNNTrainer.ftl
→
src/main/resources/templates/
gluon
/CNNTrainer.ftl
View file @
f9a53bef
File moved
src/main/resources/templates/
mxnet
/elements/Add.ftl
→
src/main/resources/templates/
gluon
/elements/Add.ftl
View file @
f9a53bef
<#--
TODO: May put this in an extra HybridBlock -->
<#
assign
mode = definition_mode.toString()>
<#
if
mode == "FORWARD_FUNCTION">
$
{
element
.name
}
= $
{
tc
.join
(
element
.inputs
,
" + "
)}
<#
include
"OutputShape.ftl">
\ No newline at end of file
</#
if
>
\ No newline at end of file
src/main/resources/templates/gluon/elements/BatchNorm.ftl
0 → 100644
View file @
f9a53bef
<#
assign
mode = definition_mode.toString()>
<#
assign
input = element.inputs[0]>
<#--
TODO: Find a solution for the CNNArch fix_gamma parameter of BatchNorm. Gluon does not provide this parameter.-->
<#
if
mode == "ARCHITECTURE_DEFINITION">
self.$
{
element
.name
}
= gluon.nn.BatchNorm()