Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
monticore
EmbeddedMontiArc
generators
CNNArch2X
Commits
80c0e869
Commit
80c0e869
authored
Feb 03, 2022
by
Evgeny Kusmenko
Browse files
Merge branch 'ba-baumann' into 'master'
GNN and DGL support, merge after CNNArchLang See merge request
!28
parents
643960d4
79649129
Pipeline
#639851
passed with stage
in 1 minute and 54 seconds
Changes
8
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java
View file @
80c0e869
...
...
@@ -314,7 +314,11 @@ public class ArchitectureElementData {
public
int
getValuesDim
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
VALUES_DIM_NAME
).
get
();
}
public
int
getNodes
(){
return
getLayerSymbol
().
getIntValue
(
AllPredefinedLayers
.
NODES_NAME
).
get
();
}
@Nullable
public
List
<
Integer
>
getPadding
(){
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java
View file @
80c0e869
...
...
@@ -106,6 +106,7 @@ public abstract class CNNArchTemplateController {
public
boolean
containsAdaNet
(){
return
this
.
architecture
.
containsAdaNet
();
}
public
String
getName
(
ArchitectureElementSymbol
layer
){
return
nameManager
.
getName
(
layer
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java
View file @
80c0e869
...
...
@@ -6,6 +6,7 @@ import com.google.common.collect.Maps;
import
de.monticore.lang.monticar.cnnarch.generator.annotations.ArchitectureAdapter
;
import
de.monticore.lang.monticar.cnnarch.generator.annotations.Range
;
import
de.monticore.lang.monticar.cnnarch.generator.training.RlAlgorithm
;
import
de.monticore.lang.monticar.cnnarch.generator.training.NetworkType
;
import
de.monticore.lang.monticar.cnnarch.generator.training.TrainingComponentsContainer
;
import
de.monticore.lang.monticar.cnnarch.generator.training.TrainingConfiguration
;
...
...
@@ -98,6 +99,21 @@ public abstract class ConfigurationData {
return
normalizeOpt
.
orElse
(
null
);
}
public
Boolean
getMultiGraph
()
{
Optional
<
Boolean
>
multiGraphOpt
=
trainingConfiguration
.
getMultiGraph
();
return
multiGraphOpt
.
orElse
(
null
);
}
public
List
<
Integer
>
getTrainMask
()
{
Optional
<
List
<
Integer
>>
trainMaskOpt
=
trainingConfiguration
.
getTrainMask
();
return
trainMaskOpt
.
orElse
(
null
);
}
public
List
<
Integer
>
getTestMask
()
{
Optional
<
List
<
Integer
>>
testMaskOpt
=
trainingConfiguration
.
getTestMask
();
return
testMaskOpt
.
orElse
(
null
);
}
public
Boolean
getShuffleData
()
{
Optional
<
Boolean
>
shuffleDataOpt
=
trainingConfiguration
.
getShuffleData
();
return
shuffleDataOpt
.
orElse
(
null
);
...
...
@@ -379,6 +395,16 @@ public abstract class ConfigurationData {
return
DQN
;
}
public
String
getNetworkType
()
{
Optional
<
NetworkType
>
networkTypeOpt
=
trainingConfiguration
.
getNetworkType
();
NetworkType
networkType
=
networkTypeOpt
.
get
();
if
(
networkType
.
equals
(
NetworkType
.
GNN
))
{
return
GNN
;
}
return
null
;
}
// protected Object getDefaultValueOrElse(String parameterKey, Object elseValue) {
// if (schema == null) {
// return elseValue;
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/NetworkType.java
0 → 100644
View file @
80c0e869
package
de.monticore.lang.monticar.cnnarch.generator.training
;
public
enum
NetworkType
{
GNN
(
"gnn"
);
String
type
;
NetworkType
(
String
type
)
{
this
.
type
=
type
;
}
public
static
NetworkType
networkType
(
String
type
)
{
for
(
NetworkType
nt
:
values
())
{
if
(
nt
.
type
.
equals
(
type
))
{
return
nt
;
}
}
throw
new
IllegalArgumentException
(
String
.
valueOf
(
type
));
}
public
String
getType
()
{
return
type
;
}
}
\ No newline at end of file
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/TrainingConfiguration.java
View file @
80c0e869
...
...
@@ -107,6 +107,15 @@ public class TrainingConfiguration {
return
Optional
.
of
(
RlAlgorithm
.
rlAlgorithm
(
rlAlgorithm
));
}
public
Optional
<
NetworkType
>
getNetworkType
()
{
Optional
<
ConfigurationEntry
>
networkTypeOpt
=
configurationSymbol
.
getConfigurationEntry
(
NETWORK_TYPE
);
if
(!
networkTypeOpt
.
isPresent
())
{
return
Optional
.
empty
();
}
String
networkType
=
(
String
)
networkTypeOpt
.
get
().
getValue
();
return
Optional
.
of
(
NetworkType
.
networkType
(
networkType
));
}
public
Optional
<
Integer
>
getBatchSize
()
{
return
getParameterValue
(
BATCH_SIZE
);
}
...
...
@@ -143,6 +152,18 @@ public class TrainingConfiguration {
return
getParameterValue
(
SHUFFLE_DATA
);
}
public
Optional
<
Boolean
>
getMultiGraph
()
{
return
getParameterValue
(
MULTI_GRAPH
);
}
public
Optional
<
List
<
Integer
>>
getTrainMask
()
{
return
getParameterValue
(
TRAIN_MASK
);
}
public
Optional
<
List
<
Integer
>>
getTestMask
()
{
return
getParameterValue
(
TEST_MASK
);
}
public
Optional
<
Double
>
getClipGlobalGradNorm
()
{
return
getParameterValue
(
CLIP_GLOBAL_GRAD_NORM
);
}
...
...
src/main/java/de/monticore/lang/monticar/cnnarch/generator/training/TrainingParameterConstants.java
View file @
80c0e869
...
...
@@ -42,6 +42,7 @@ public class TrainingParameterConstants {
public
static
final
String
NORMALIZE
=
"normalize"
;
public
static
final
String
CONTEXT
=
"context"
;
public
static
final
String
SHUFFLE_DATA
=
"shuffle_data"
;
public
static
final
String
CLIP_GLOBAL_GRAD_NORM
=
"clip_global_grad_norm"
;
public
static
final
String
USE_TEACHER_FORCING
=
"use_teacher_forcing"
;
public
static
final
String
SAVE_ATTENTION_IMAGE
=
"save_attention_image"
;
...
...
@@ -57,6 +58,12 @@ public class TrainingParameterConstants {
public
static
final
String
DDPG
=
"ddpg"
;
public
static
final
String
TD3
=
"td3"
;
public
static
final
String
MULTI_GRAPH
=
"multi_graph"
;
public
static
final
String
TRAIN_MASK
=
"train_mask"
;
public
static
final
String
TEST_MASK
=
"test_mask"
;
public
static
final
String
GNN
=
"gnn"
;
public
static
final
String
NETWORK_TYPE
=
"network_type"
;
public
static
final
String
LEARNING_METHOD
=
"learning_method"
;
public
static
final
String
EVAL_METRIC
=
"eval_metric"
;
public
static
final
String
NUM_EPISODES
=
"num_episodes"
;
...
...
src/main/resources/schemas/GNN.scm
0 → 100644
View file @
80c0e869
/*
(
c
)
https://github
.
com/MontiCore/monticore
*/
schema
GNN
extends
Supervised
{
train_mask:
Z*
test_mask:
Z*
multi_graph:
B
}
src/main/resources/schemas/Supervised.scm
View file @
80c0e869
...
...
@@ -4,6 +4,10 @@ import Loss;
schema
Supervised
extends
General
{
network_type:
schema
{
gnn
;
}
batch_size:
N1
num_epoch:
N1
normalize:
B
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment