Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2X
Commits
a57778c1
Commit
a57778c1
authored
Feb 16, 2022
by
Dmytro Semenchenko
Browse files
Merge branch 'master' into onnx-dmytro
parents
69e2a60a
65a4d7a1
Pipeline
#649193
passed with stage
in 1 minute and 24 seconds
Changes
16
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
pom.xml
View file @
a57778c1
...
...
@@ -245,3 +245,4 @@
</snapshotRepository>
</distributionManagement>
</project>
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java
View file @
a57778c1
...
...
@@ -90,6 +90,10 @@ public class ArchitectureElementData {
return
getLayerSymbol
().
getIntTupleValue
(
AllPredefinedLayers
.
STRIDE_NAME
).
get
();
}
public
String
getPdf
()
{
return
getLayerSymbol
().
getStringValue
(
AllPredefinedLayers
.
PDF_NAME
).
get
();
}
public
int
getNumEmbeddings
()
{
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
NUM_EMBEDDINGS_NAME
).
get
();
}
public
int
getGroups
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
GROUPS_NAME
).
get
();
}
...
...
@@ -314,7 +318,11 @@ public class ArchitectureElementData {
public
int
getValuesDim
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
VALUES_DIM_NAME
).
get
();
}
public
int
getNodes
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
NODES_NAME
).
get
();
}
@Nullable
public
List
<
Integer
>
getPadding
(){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java
View file @
a57778c1
...
...
@@ -106,6 +106,7 @@ public abstract class CNNArchTemplateController {
public
boolean
containsAdaNet
(){
return
this
.
architecture
.
containsAdaNet
();
}
public
String
getName
(
ArchitectureElementSymbol
layer
){
return
nameManager
.
getName
(
layer
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java
View file @
a57778c1
...
...
@@ -6,6 +6,7 @@ import com.google.common.collect.Maps;
import
de.monticore.lang.monticar.cnnarch.generator.annotations.ArchitectureAdapter
;
import
de.monticore.lang.monticar.cnnarch.generator.annotations.Range
;
import
de.monticore.lang.monticar.cnnarch.generator.training.RlAlgorithm
;
import
de.monticore.lang.monticar.cnnarch.generator.training.NetworkType
;
import
de.monticore.lang.monticar.cnnarch.generator.training.TrainingComponentsContainer
;
import
de.monticore.lang.monticar.cnnarch.generator.training.TrainingConfiguration
;
...
...
@@ -39,6 +40,10 @@ public abstract class ConfigurationData {
return
trainingConfiguration
.
isGanLearning
();
}
public
Boolean
isVaeLearning
()
{
return
trainingConfiguration
.
isVaeLearning
();
}
public
Boolean
isReinforcementLearning
()
{
return
trainingConfiguration
.
isReinforcementLearning
();
}
...
...
@@ -73,6 +78,16 @@ public abstract class ConfigurationData {
return
loadPretrainedOpt
.
orElse
(
null
);
}
public
Double
getKlLossWeight
()
{
Optional
<
Double
>
klLossWeightOpt
=
trainingConfiguration
.
getKlLossWeight
();
return
klLossWeightOpt
.
orElse
(
null
);
}
public
String
getReconLossName
()
{
Optional
<
String
>
reconLossNameOpt
=
trainingConfiguration
.
getReconLossName
();
return
reconLossNameOpt
.
orElse
(
null
);
}
// COMPARE WITH CNNTRAIN IMPLEMENTATION IN GluonConfigurationData
public
Boolean
getPreprocessor
()
{
Optional
<
String
>
preprocessorOpt
=
trainingConfiguration
.
getPreprocessor
();
...
...
@@ -89,6 +104,21 @@ public abstract class ConfigurationData {
return
onnxExport
.
orElse
(
null
);
}
public
Boolean
getMultiGraph
()
{
Optional
<
Boolean
>
multiGraphOpt
=
trainingConfiguration
.
getMultiGraph
();
return
multiGraphOpt
.
orElse
(
null
);
}
public
List
<
Integer
>
getTrainMask
()
{
Optional
<
List
<
Integer
>>
trainMaskOpt
=
trainingConfiguration
.
getTrainMask
();
return
trainMaskOpt
.
orElse
(
null
);
}
public
List
<
Integer
>
getTestMask
()
{
Optional
<
List
<
Integer
>>
testMaskOpt
=
trainingConfiguration
.
getTestMask
();
return
testMaskOpt
.
orElse
(
null
);
}
public
Boolean
getShuffleData
()
{
Optional
<
Boolean
>
shuffleDataOpt
=
trainingConfiguration
.
getShuffleData
();
return
shuffleDataOpt
.
orElse
(
null
);
...
...
@@ -352,7 +382,11 @@ public abstract class ConfigurationData {
// public Map<String, Map<String, Object>> getConstraintLosses() { // TODO
// return getMultiParamMapEntry(CONSTRAINT_LOSS, "name");
// }
public
String
getSelfPlay
()
{
// added Parameter self_play for cooperative driving
Optional
<
String
>
selfPlay
=
trainingConfiguration
.
getSelfPlay
();
return
selfPlay
.
orElse
(
null
);
}
public
String
getRlAlgorithm
()
{
Optional
<
RlAlgorithm
>
rlAlgorithmOpt
=
trainingConfiguration
.
getRlAlgorithm
();
if
(!
rlAlgorithmOpt
.
isPresent
())
{
...
...
@@ -370,6 +404,16 @@ public abstract class ConfigurationData {
return
DQN
;
}
public
String
getNetworkType
()
{
Optional
<
NetworkType
>
networkTypeOpt
=
trainingConfiguration
.
getNetworkType
();
NetworkType
networkType
=
networkTypeOpt
.
get
();
if
(
networkType
.
equals
(
NetworkType
.
GNN
))
{
return
GNN
;
}
return
null
;
}
// protected Object getDefaultValueOrElse(String parameterKey, Object elseValue) {
// if (schema == null) {
// return elseValue;
...
...
@@ -631,4 +675,4 @@ public abstract class ConfigurationData {
}
return
object
.
toString
();
}
}
\ No newline at end of file
}
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/LearningMethod.java
View file @
a57778c1
...
...
@@ -4,7 +4,8 @@ public enum LearningMethod {
SUPERVISED
(
"supervised"
),
REINFORCEMENT
(
"reinforcement"
),
GAN
(
"gan"
);
GAN
(
"gan"
),
VAE
(
"vae"
);
String
method
;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/NetworkType.java
0 → 100644
View file @
a57778c1
package
de.monticore.lang.monticar.cnnarch.generator.training
;
public
enum
NetworkType
{
GNN
(
"gnn"
);
String
type
;
NetworkType
(
String
type
)
{
this
.
type
=
type
;
}
public
static
NetworkType
networkType
(
String
type
)
{
for
(
NetworkType
nt
:
values
())
{
if
(
nt
.
type
.
equals
(
type
))
{
return
nt
;
}
}
throw
new
IllegalArgumentException
(
String
.
valueOf
(
type
));
}
public
String
getType
()
{
return
type
;
}
}
\ No newline at end of file
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/TrainingComponentsContainer.java
View file @
a57778c1
...
...
@@ -20,6 +20,8 @@ public class TrainingComponentsContainer {
private
ArchitectureAdapter
actorNetwork
;
private
ArchitectureAdapter
criticNetwork
;
private
ArchitectureAdapter
generatorNetwork
;
private
ArchitectureAdapter
encoderNetwork
;
private
ArchitectureAdapter
decoderNetwork
;
private
ArchitectureAdapter
discriminatorNetwork
;
private
ArchitectureAdapter
qNetwork
;
private
EMAComponentInstanceSymbol
rewardFunction
;
...
...
@@ -43,11 +45,18 @@ public class TrainingComponentsContainer {
return
Optional
.
ofNullable
(
generatorNetwork
);
}
public
Optional
<
ArchitectureAdapter
>
getDiscriminatorNetwork
()
{
return
Optional
.
ofNullable
(
discriminatorNetwork
);
}
public
Optional
<
ArchitectureAdapter
>
getDecoderNetwork
()
{
return
Optional
.
ofNullable
(
decoderNetwork
);
}
public
Optional
<
ArchitectureAdapter
>
getEncoderNetwork
()
{
return
Optional
.
ofNullable
(
encoderNetwork
);
}
public
Optional
<
ArchitectureAdapter
>
getQNetwork
()
{
return
Optional
.
ofNullable
(
qNetwork
);
}
...
...
@@ -80,6 +89,8 @@ public class TrainingComponentsContainer {
}
}
else
if
(
trainingConfiguration
.
isGanLearning
())
{
setGeneratorNetwork
(
trainedArchitecture
);
}
else
if
(
trainingConfiguration
.
isVaeLearning
())
{
setDecoderNetwork
(
trainedArchitecture
);
}
}
...
...
@@ -108,6 +119,16 @@ public class TrainingComponentsContainer {
addTrainingComponent
(
QNETWORK
,
qNetwork
);
}
public
void
setDecoderNetwork
(
ArchitectureAdapter
decoderNetwork
)
{
this
.
decoderNetwork
=
decoderNetwork
;
addTrainingComponent
(
DECODER
,
decoderNetwork
);
}
public
void
setEncoderNetwork
(
ArchitectureAdapter
encoderNetwork
)
{
this
.
encoderNetwork
=
encoderNetwork
;
addTrainingComponent
(
ENCODER
,
encoderNetwork
);
}
public
void
setRewardFunction
(
EMAComponentInstanceSymbol
rewardFunction
)
{
this
.
rewardFunction
=
rewardFunction
;
addTrainingComponent
(
REWARD_FUNCTION
,
rewardFunction
.
getComponentType
());
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/TrainingConfiguration.java
View file @
a57778c1
...
...
@@ -7,10 +7,7 @@ import conflang._symboltable.ConfigurationEntrySymbol;
import
conflang._symboltable.NestedConfigurationEntrySymbol
;
import
schemalang._symboltable.SchemaDefinitionSymbol
;
import
java.util.Collection
;
import
java.util.List
;
import
java.util.Map
;
import
java.util.Optional
;
import
java.util.*
;
import
static
de
.
monticore
.
lang
.
monticar
.
cnnarch
.
generator
.
training
.
TrainingParameterConstants
.*;
...
...
@@ -50,6 +47,10 @@ public class TrainingConfiguration {
return
getParameterValue
(
CONTEXT
);
}
public
Optional
<
String
>
getSelfPlay
()
{
return
getParameterValue
(
SELF_PLAY
);
}
public
Optional
<
LearningMethod
>
getLearningMethod
()
{
Optional
<
ConfigurationEntry
>
learningMethodOpt
=
configurationSymbol
.
getConfigurationEntry
(
LEARNING_METHOD
);
...
...
@@ -83,6 +84,15 @@ public class TrainingConfiguration {
return
LearningMethod
.
GAN
.
equals
(
learningMethod
);
}
public
boolean
isVaeLearning
()
{
Optional
<
LearningMethod
>
learningMethodOpt
=
getLearningMethod
();
if
(!
learningMethodOpt
.
isPresent
())
{
return
false
;
// Not correct here to return false..
}
LearningMethod
learningMethod
=
learningMethodOpt
.
get
();
return
LearningMethod
.
VAE
.
equals
(
learningMethod
);
}
public
boolean
isReinforcementLearning
()
{
Optional
<
LearningMethod
>
learningMethodOpt
=
getLearningMethod
();
if
(!
learningMethodOpt
.
isPresent
())
{
...
...
@@ -101,6 +111,15 @@ public class TrainingConfiguration {
return
Optional
.
of
(
RlAlgorithm
.
rlAlgorithm
(
rlAlgorithm
));
}
public
Optional
<
NetworkType
>
getNetworkType
()
{
Optional
<
ConfigurationEntry
>
networkTypeOpt
=
configurationSymbol
.
getConfigurationEntry
(
NETWORK_TYPE
);
if
(!
networkTypeOpt
.
isPresent
())
{
return
Optional
.
empty
();
}
String
networkType
=
(
String
)
networkTypeOpt
.
get
().
getValue
();
return
Optional
.
of
(
NetworkType
.
networkType
(
networkType
));
}
public
Optional
<
Integer
>
getBatchSize
()
{
return
getParameterValue
(
BATCH_SIZE
);
}
...
...
@@ -141,6 +160,18 @@ public class TrainingConfiguration {
return
getParameterValue
(
SHUFFLE_DATA
);
}
public
Optional
<
Boolean
>
getMultiGraph
()
{
return
getParameterValue
(
MULTI_GRAPH
);
}
public
Optional
<
List
<
Integer
>>
getTrainMask
()
{
return
getParameterValue
(
TRAIN_MASK
);
}
public
Optional
<
List
<
Integer
>>
getTestMask
()
{
return
getParameterValue
(
TEST_MASK
);
}
public
Optional
<
Double
>
getClipGlobalGradNorm
()
{
return
getParameterValue
(
CLIP_GLOBAL_GRAD_NORM
);
}
...
...
@@ -353,6 +384,20 @@ public class TrainingConfiguration {
return
getObjectParameterParameters
(
DISCRIMINATOR_OPTIMIZER
);
}
public
Boolean
hasEncoderName
()
{
return
hasParameter
(
ENCODER
);
}
public
Optional
<
String
>
getEncoderName
()
{
return
getObjectParameterValue
(
ENCODER
);
}
public
Optional
<
String
>
getReconLossName
()
{
return
getObjectParameterValue
(
RECON_LOSS
);
}
public
Optional
<
Double
>
getKlLossWeight
()
{
return
getParameterValue
(
KL_LOSS_WEIGHT
);
}
public
boolean
hasStrategy
()
{
return
hasParameter
(
STRATEGY
);
}
...
...
@@ -513,4 +558,5 @@ public class TrainingConfiguration {
}
return
keyValues
;
}
}
\ No newline at end of file
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/TrainingParameterConstants.java
View file @
a57778c1
...
...
@@ -15,6 +15,7 @@ public class TrainingParameterConstants {
public
static
final
String
SUPERVISED
=
"supervised"
;
public
static
final
String
REINFORCEMENT
=
"reinforcement"
;
public
static
final
String
GAN
=
"gan"
;
public
static
final
String
VAE
=
"vae"
;
/*
* Optimizers
...
...
@@ -41,6 +42,7 @@ public class TrainingParameterConstants {
public
static
final
String
NORMALIZE
=
"normalize"
;
public
static
final
String
CONTEXT
=
"context"
;
public
static
final
String
SHUFFLE_DATA
=
"shuffle_data"
;
public
static
final
String
CLIP_GLOBAL_GRAD_NORM
=
"clip_global_grad_norm"
;
public
static
final
String
USE_TEACHER_FORCING
=
"use_teacher_forcing"
;
public
static
final
String
SAVE_ATTENTION_IMAGE
=
"save_attention_image"
;
...
...
@@ -56,6 +58,13 @@ public class TrainingParameterConstants {
public
static
final
String
DQN
=
"dqn"
;
public
static
final
String
DDPG
=
"ddpg"
;
public
static
final
String
TD3
=
"td3"
;
public
static
final
String
SELF_PLAY
=
"self_play"
;
public
static
final
String
MULTI_GRAPH
=
"multi_graph"
;
public
static
final
String
TRAIN_MASK
=
"train_mask"
;
public
static
final
String
TEST_MASK
=
"test_mask"
;
public
static
final
String
GNN
=
"gnn"
;
public
static
final
String
NETWORK_TYPE
=
"network_type"
;
public
static
final
String
LEARNING_METHOD
=
"learning_method"
;
public
static
final
String
EVAL_METRIC
=
"eval_metric"
;
...
...
@@ -111,4 +120,9 @@ public class TrainingParameterConstants {
public
static
final
String
GENERATOR_LOSS_WEIGHT
=
"generator_loss_weight"
;
public
static
final
String
DISCRIMINATOR_LOSS_WEIGHT
=
"discriminator_loss_weight"
;
public
static
final
String
PRINT_IMAGES
=
"print_images"
;
public
static
final
String
ENCODER
=
"encoder"
;
public
static
final
String
DECODER
=
"decoder"
;
public
static
final
String
KL_LOSS_WEIGHT
=
"kl_loss_weight"
;
public
static
final
String
RECON_LOSS
=
"reconstruction_loss"
;
}
\ No newline at end of file
src/main/resources/schemas/GNN.scm
0 → 100644
View file @
a57778c1
/*
(
c
)
https://github
.
com/MontiCore/monticore
*/
schema
GNN
extends
Supervised
{
train_mask:
Z*
test_mask:
Z*
multi_graph:
B
}
src/main/resources/schemas/General.scm
View file @
a57778c1
...
...
@@ -4,7 +4,7 @@ import Optimizer;
schema
General
{
learning_method
=
supervised:
schema
{
supervised,
reinforcement,
gan
;
supervised,
reinforcement,
gan
,
vae
;
}
context:
enum
{
...
...
src/main/resources/schemas/Reinforcement.scm
View file @
a57778c1
...
...
@@ -8,6 +8,10 @@ schema Reinforcement extends General {
dqn,
ddpg,
td3
;
}
self_play:
enum
{
no,
yes
;
}
agent_name:
string
num_episodes
=
50
:
N1
num_max_steps
=
99999
:
N
...
...
@@ -21,4 +25,4 @@ schema Reinforcement extends General {
actor_optimizer:
optimizer_type
environment:
environment_type!
replay_memory
=
buffer:
replay_memory_type
}
\ No newline at end of file
}
src/main/resources/schemas/Supervised.scm
View file @
a57778c1
...
...
@@ -4,6 +4,10 @@ import Loss;
schema
Supervised
extends
General
{
network_type:
schema
{
gnn
;
}
batch_size:
N1
num_epoch:
N
normalize:
B
...
...
src/main/resources/schemas/VAE.scm
0 → 100644
View file @
a57778c1
/*
(
c
)
https://github
.
com/MontiCore/monticore
*/
import
Optimizer
;
schema
VAE
extends
General
{
reference-model:
referencemodels
.
vae
.
VAE,
referencemodels
.
vae
.
CVAE
batch_size:
N1
num_epoch:
N1
normalize:
B
checkpoint_period
=
5
:
N
load_checkpoint:
B
load_pretrained:
B
log_period:
N
reconstruction_loss
=
mse:
reconLoss_type
print_images
=
false:
B
kl_loss_weight:
Q
reconLoss_type
{
values:
bce,
mse
;
}
}
\ No newline at end of file
src/main/resources/schemas/referencemodels/vae/CVAE.ema
0 → 100644
View file @
a57778c1
/*
(
c
)
https
://
github
.
com
/
MontiCore
/
monticore
*/
package
referencemodels
.
vae
;
component
CVAE
{
component
Encoder
{
ports
in
X
data
,
in
W
^{
1
}
label
,
out
D
encoding
;
}
component
Decoder
{
ports
in
D
encoding
,
in
W
^{
1
}
label
,
out
X
data
;
}
instance
Encoder
encoder
;
instance
Decoder
decoder
;
connect
encoder
.
encoding
->
decoder
.
encoding
;
}
\ No newline at end of file
src/main/resources/schemas/referencemodels/vae/VAE.ema
0 → 100644
View file @
a57778c1
/*
(
c
)
https
://
github
.
com
/
MontiCore
/
monticore
*/
package
referencemodels
.
vae
;
component
VAE
{
component
Encoder
{
ports
in
X
data
,
out
D
encoding
;
}
component
Decoder
{
ports
in
D
encoding
,
out
X
data
;
}
instance
Encoder
encoder
;
instance
Decoder
decoder
;
connect
encoder
.
encoding
->
decoder
.
encoding
;
}
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment