Commit 5e7ce0fd authored by Mattis Hoppe's avatar Mattis Hoppe
Browse files

graphic interface for RL

parent 6b0bdb62
Pipeline #493730 passed with stage
in 5 minutes and 15 seconds
......@@ -165,8 +165,6 @@ public class RLAutopilot extends EEComponent implements Inspectable {
turnOutput = (double) action[2] * 30;
}
//TODO set state array from state values
//is this. needed?
setState();
setSteering(sendTime, turnOutput);
setGas(sendTime, speedOutput);
......
......@@ -330,6 +330,7 @@ public class PathfindingImpl implements Pathfinding {
// Add all completely traversed segments
do {
lastEdgeID = vertexPredecessorEdge[s.startNodeID];
if(lastEdgeID == -1) return new Path(0);
s = world.getWaySegment(lastEdgeID);
if (lastEdgeID == startRef.roadSegmentID || lastEdgeID == startRef.reverseId) break; // Reached start segment
inc = s.pointsStart < s.pointsEnd ? -1 : 1;
......
......@@ -27,18 +27,11 @@ public class RLRewardCalculator{
public float getReward(){
//fixed layout for state[] in following order: traj_x,traj_y,traj_length, currentpos_x and currentpos_y, current_compass, current_velocity
//get current positon of each vehicle
//get trajectory of each vehicle
float reward = 0;
for(int i = 0; i<truePositions.length; i++){
reward += getRewardForVehicle(i);
}
//TODO calculate lines to drive
//TODO take squared error as negative reward
//TODO also punish driving slowly
return reward;
}
......
......@@ -5,16 +5,23 @@ import java.time.Instant;
import de.rwth.montisim.commons.simulation.TaskStatus;
import de.rwth.montisim.commons.simulation.TimeUpdate;
import de.rwth.montisim.commons.map.Pathfinding;
import de.rwth.montisim.commons.utils.Coordinates;
import de.rwth.montisim.commons.utils.json.SerializationException;
import de.rwth.montisim.commons.utils.Vec2;
import de.rwth.montisim.simulation.environment.pathfinding.PathfindingImpl;
import de.rwth.montisim.simulation.environment.osmmap.*;
import de.rwth.montisim.simulation.environment.world.World;
import de.rwth.montisim.simulation.simulator.SimulationConfig;
import de.rwth.montisim.simulation.simulator.Simulator;
import de.rwth.montisim.simulation.eecomponents.autopilots.*;
import de.rwth.montisim.simulation.vehicle.task.TaskProperties;
import de.rwth.montisim.simulation.vehicle.task.metric.MetricGoalProperties;
import de.rwth.montisim.simulation.vehicle.task.path.PathGoalProperties;
import de.rwth.montisim.simulation.vehicle.Vehicle;
import de.rwth.montisim.simulation.vehicle.VehicleProperties;
import de.rwth.montisim.simulation.environment.osmmap.*;
import de.rwth.montisim.simulation.vehicle.navigation.Navigation;
import de.rwth.montisim.simulation.simulator.visualization.rl.RLVisualizer;
import org.ros.message.MessageListener;
......@@ -31,10 +38,13 @@ import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.ConcurrentModificationException;
public class RLSimulationHandler extends AbstractNodeMain{
private Simulator simulator = null;
public Simulator simulator = null;
private SimulationConfig config;
Instant simulationTime;
private OsmMap map;
......@@ -42,10 +52,14 @@ public class RLSimulationHandler extends AbstractNodeMain{
private Pathfinding pathfinding;
private Collection<Vehicle> vehicles;
private Vehicle[] vehiclesArray;
private RLAutopilot [] autopilots = new RLAutopilot[1];
private RLAutopilot[] autopilots;
private Navigation[] navigations;
private RLRewardCalculator rewardCalc;
private long seed = 48965l;
private Random rndGen = new Random();
private boolean distributed = false;
private int activeVehicle = 0; //decides which car is active for distributed learning
private RLVisualizer viz;
private boolean in_termination = true;
private boolean in_reset = false;
......@@ -55,7 +69,7 @@ public class RLSimulationHandler extends AbstractNodeMain{
private class Result{
float reward = 0f;
float[] state; //standard dimension: [6][10]
float[] state = null;
boolean terminated = true;
public Result (float reward, float[] state, boolean terminated){
......@@ -65,12 +79,15 @@ public class RLSimulationHandler extends AbstractNodeMain{
}
}
public RLSimulationHandler (SimulationConfig config, Instant simulationTime, World world, Pathfinding pathfinding, OsmMap map){
public RLSimulationHandler (SimulationConfig config, Instant simulationTime, World world, Pathfinding pathfinding, OsmMap map, RLVisualizer viz){
this.config = config;
this.simulationTime = simulationTime;
this.world = world;
this.pathfinding = pathfinding;
this.map = map;
this.viz = viz;
rndGen.setSeed(seed);
//System.out.println(pathfinding == null);
}
@Override
......@@ -103,16 +120,18 @@ public class RLSimulationHandler extends AbstractNodeMain{
action_subscriber.addMessageListener(new MessageListener<std_msgs.Float32MultiArray>() {
@Override
public void onNewMessage(std_msgs.Float32MultiArray action) {
if(in_termination || in_reset) return; //can add debug information later
Result result = step(action.getData());
//wait_done();
if(in_termination || in_reset) return;
Result result = step(action.getData());
std_msgs.Float32MultiArray state = state_publisher.newMessage();
state.setData(result.state);
std_msgs.Bool terminated = terminate_publisher.newMessage();
terminated.setData(result.terminated);
std_msgs.Float32 reward = reward_publisher.newMessage();
reward.setData(result.reward);
if(result.terminated) in_termination = true;
state_publisher.publish(state);
terminate_publisher.publish(terminated);
reward_publisher.publish(reward);
......@@ -127,8 +146,7 @@ public class RLSimulationHandler extends AbstractNodeMain{
if(reset.getData() && in_termination){
in_reset = true;
Result result = reset();
//wait_done();
//Thread.sleep(200);
std_msgs.Float32MultiArray state = state_publisher.newMessage();
state.setData(result.state);
std_msgs.Bool terminated = terminate_publisher.newMessage();
......@@ -163,52 +181,83 @@ public class RLSimulationHandler extends AbstractNodeMain{
private Result setup(){
if(simulator != null){
//wait_done();
simulator.destroy();
simulator = null;
}
//add randomization method here
randomizeScenario();
if(viz!= null) viz.clearRenderer();
//System.out.println(pathfinding == null);
simulator = config.build(world, pathfinding, map);
vehicles = simulator.getVehicles();
//need to get autopilots from vehicle
vehiclesArray = vehicles.toArray(new Vehicle[0]);
//for now only add one autopilot later method with more than one
//check type of autopilot and not null and send custom error message
//check with if whether null or not and type TODO
autopilots[0] = (RLAutopilot) vehiclesArray[0].eesystem.getComponent("RLAutopilot").get();
if(viz != null){
viz.simTime = simulationTime;
viz.setup();
}
autopilots = new RLAutopilot[vehiclesArray.length];
for(int i = 0; i<autopilots.length; i++){
autopilots[i] = (RLAutopilot) vehiclesArray[i].eesystem.getComponent("RLAutopilot").get();
}
navigations = new Navigation[autopilots.length];
for(int i = 0;i<navigations.length;i++){
navigations[i] = (Navigation) vehiclesArray[i].eesystem.getComponent("Navigation").get();
}
//setup ready, now one step is needed so physical values are assigned
while(autopilots[0].state == null){
//update simulation until all values are assigned
while(anyStateNull()){
TimeUpdate tu = new TimeUpdate(simulationTime, config.tick_duration);
simulator.update(tu);
simulationTime = tu.newTime;
}
//TODO get values to give to emadl autopilot from custom java autopilot
float[] simState = autopilots[0].state;
activeVehicle = 0;
float[] simState = getState();
rewardCalc = new RLRewardCalculator(navigations, vehiclesArray);
float init_reward = rewardCalc.getReward();
boolean simTermination = this.getSimTermination();
done = true;
//done = true;
//check if all vehicles found a path, if not, restart simulation
for(int i = 0; i<navigations.length; i++){
System.out.println(navigations[i].getCurrentPath().get().getLength());
if(navigations[i].getCurrentPath().get().getLength() == 0)
return setup();
}
//if(viz!=null){
// viz.redraw();
// viz.simTime = simulationTime;
//}
return new Result(init_reward, simState, simTermination);
}
private Result step(float[] action){
autopilots[0].action = action; //maybe divide action by number of vehicles to determine which goes to which
//vehicles not ordered...
TimeUpdate tu = new TimeUpdate(simulationTime, config.tick_duration);
simulator.update(tu);
simulationTime = tu.newTime;
float[] simState = autopilots[0].state;
//with this method the computation is always one step delayed
setAction(action);
//TimeUpdate tu = new TimeUpdate(simulationTime, config.tick_duration);
if(activeVehicle == vehiclesArray.length - 1 || !distributed){
TimeUpdate tu = new TimeUpdate(simulationTime, config.tick_duration);
simulator.update(tu);
simulationTime = tu.newTime;
}
float[] simState = getState();
float step_reward = rewardCalc.getReward();
boolean simTermination = this.getSimTermination();
activeVehicle = (activeVehicle + 1)%vehiclesArray.length;
if(viz != null){
try{viz.simTime = simulationTime;
viz.redraw();
//viz.viewer.repaint();
}
catch(ConcurrentModificationException ignore) {}
}
done = true;
return new Result(step_reward, simState, simTermination);
}
......@@ -223,7 +272,97 @@ public class RLSimulationHandler extends AbstractNodeMain{
if(simulator.status() == TaskStatus.RUNNING) return false;
else return true;
}
//for more than one vehicle: public float[][] combineActions(){}
//get combined state of all vehicles
private float[] getState(){
int vehicleCount = autopilots.length;
int stateLength = autopilots[0].state.length;
float[] result;
if(!distributed){
result = new float[vehicleCount * stateLength];
for(int i = 0; i<vehicleCount; i++){
for(int j = 0; j<stateLength; j++){
result[i*stateLength + j] = autopilots[i].state[j];
}
}
}
else{
result = new float[stateLength];
for(int i = 0; i<stateLength; i++){
result[i] = autopilots[activeVehicle].state[i];
}
}
return result;
}
//set the combined action from all vehicles
private void setAction(float[] action){
int vehicleCount = autopilots.length;
int actionLength = action.length / vehicleCount; //assume that every vehicle has same action space
if(!distributed){
for(int i = 0; i<vehicleCount; i++){
float[] result = new float[actionLength];
for(int j = 0; j<actionLength; j++){
//autopilots[i].action[j] = action[i * actionLength + j];
result[j] = action[i * actionLength + j];
}
autopilots[i].action = result;
}
}
else{
autopilots[activeVehicle].action = action;
}
return;
}
private boolean anyStateNull(){
for(int i = 0; i<autopilots.length; i++){
if(autopilots[i].state == null) return true;
}
return false;
}
//randomize start and end position for all vehicles
private void randomizeScenario(){
//get properties of all vehicles
VehicleProperties[] properties;
properties = config.cars.toArray(new VehicleProperties[0]);
Vec2[] startCoords = new Vec2[properties.length];
Vec2[] targetCoords = new Vec2[properties.length];
int x_boundary = (int) world.maxCorner.at(0);
int y_boundary = (int) world.maxCorner.at(1);
for(int i = 0; i<properties.length; i++){
startCoords[i] = new Vec2(rndGen.nextInt()%x_boundary,rndGen.nextInt()%y_boundary);
targetCoords[i] = new Vec2(rndGen.nextInt()%x_boundary,rndGen.nextInt()%y_boundary);
}
TaskProperties[] tasks = new TaskProperties[properties.length];
for(int i = 0; i<properties.length; i++){
tasks[i] = new TaskProperties();
tasks[i].addGoal(new PathGoalProperties()
.reach(targetCoords[i])
.withinRange(10)
.eventually());
}
System.out.println("Start position: " + startCoords[0].at(0) + " , " + startCoords[0].at(1));
System.out.println("Target position: " + targetCoords[0].at(0) + " , " + targetCoords[0].at(1));
for(int i = 0; i<properties.length; i++){
properties[i].task = tasks[i];
properties[i].start_pos = Optional.of(startCoords[i]);
properties[i].start_coords = Optional.empty();
properties[i].start_osm_node = Optional.empty();
}
}
public Simulator getSim(){
return simulator;
}
}
\ No newline at end of file
......@@ -38,7 +38,7 @@ public class RLSimulationInit {
public void init() { //TaskStatus replaced by void?
NodeConfiguration rosNodeConfiguration = NodeConfiguration.newPrivate();
NodeMain rlSimulationHandler = new RLSimulationHandler(config, simulationTime, world, pathfinding, map);
NodeMain rlSimulationHandler = new RLSimulationHandler(config, simulationTime, world, pathfinding, map, null);
NodeMainExecutor nodeMainExecuter = DefaultNodeMainExecutor.newDefault();
nodeMainExecuter.execute(rlSimulationHandler, rosNodeConfiguration);
......
package de.rwth.montisim.simulation.simulator.visualization.rl;
import de.rwth.montisim.commons.map.Pathfinding;
import de.rwth.montisim.commons.simulation.TaskStatus;
import de.rwth.montisim.commons.simulation.TimeUpdate;
import de.rwth.montisim.commons.utils.IPM;
import de.rwth.montisim.commons.utils.Vec2;
import de.rwth.montisim.commons.utils.json.SerializationException;
import de.rwth.montisim.simulation.environment.pathfinding.PathfindingImpl;
import de.rwth.montisim.simulation.environment.osmmap.*;
import de.rwth.montisim.simulation.environment.world.World;
import de.rwth.montisim.simulation.simulator.SimulationConfig;
import de.rwth.montisim.simulation.simulator.Simulator;
import de.rwth.montisim.simulation.simulator.visualization.car.CarRenderer;
import de.rwth.montisim.simulation.simulator.visualization.map.PathfinderRenderer;
import de.rwth.montisim.simulation.simulator.visualization.map.WorldRenderer;
import de.rwth.montisim.simulation.simulator.visualization.ui.Control;
import de.rwth.montisim.simulation.simulator.visualization.ui.SimulationRunner;
import de.rwth.montisim.simulation.simulator.visualization.ui.UIInfo;
import de.rwth.montisim.simulation.simulator.visualization.ui.Viewer2D;
import de.rwth.montisim.simulation.vehicle.Vehicle;
import de.rwth.montisim.simulation.simulator.RLSimulationHandler;
import de.rwth.montisim.simulation.simulator.Simulator;
import javax.swing.*;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.event.*;
import java.awt.BorderLayout;
import java.io.File;
import java.time.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.ros.exception.RosRuntimeException;
import org.ros.node.DefaultNodeMainExecutor;
import org.ros.node.NodeConfiguration;
import org.ros.node.NodeMain;
import org.ros.node.NodeMainExecutor;
public class RLVisualizer{
String current_scenario = "";
JLabel scenario_name;
Control control;
Viewer2D viewer;
private List<CarRenderer> carRenderers = new ArrayList<>();
Simulator simulator;
SimulationConfig simConfig;
World world;
OsmMap map;
Pathfinding pathfinding;
public Instant simTime;
RLSimulationHandler rlSimulationHandler;
private boolean done_clearing = true;
public RLVisualizer(World world, OsmMap map, Pathfinding pathfinding, SimulationConfig simConfig, Viewer2D viewer, Instant simTime){
this.world = world;
this.map = map;
//System.out.println(pathfinding == null);
this.pathfinding = pathfinding;
//System.out.println(this.pathfinding);
this.simConfig = simConfig;
this.viewer = viewer;
this.simTime = simTime;
}
public void init(){
//System.out.println(simConfig == null);
NodeConfiguration rosNodeConfiguration = NodeConfiguration.newPrivate();
//System.out.println(pathfinding == null);
rlSimulationHandler = new RLSimulationHandler(simConfig, simTime, world, pathfinding, map, this);
NodeMainExecutor nodeMainExecuter = DefaultNodeMainExecutor.newDefault();
nodeMainExecuter.execute(rlSimulationHandler, rosNodeConfiguration);
}
public void clearRenderer(){
done_clearing = false;
if(viewer != null && !carRenderers.isEmpty()){
viewer.clearRenderers();
carRenderers.clear();
}
done_clearing = true;
}
public void setup(){
// Setup visualizer
viewer.addRenderer(new WorldRenderer(world));
viewer.addRenderer(new PathfinderRenderer(pathfinding));
// Init CarRenderers and find view for all Vehicles
//Collection<Vehicle> vehicles = rlSimulationHandler.simulator.getVehicles();
Collection<Vehicle> vehicles = rlSimulationHandler.getSim().getVehicles();
setView(vehicles);
for (Vehicle v : vehicles) {
CarRenderer cr = new CarRenderer();
cr.setCar(v);
viewer.addRenderer(cr);
carRenderers.add(cr);
}
while(!done_clearing);
viewer.repaint();
}
void setView(Collection<Vehicle> vehicles) {
Vec2 avg_pos = new Vec2(0, 0);
int count = 0;
Vec2 min_pos = new Vec2(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
Vec2 max_pos = new Vec2(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY);
Vec2 vpos = new Vec2();
for (Vehicle v : vehicles) {
vpos.set(v.physicalObject.pos);
IPM.add(avg_pos, vpos);
if (vpos.x < min_pos.x)
min_pos.x = vpos.x;
if (vpos.y < min_pos.y)
min_pos.y = vpos.y;
if (vpos.x > max_pos.x)
max_pos.x = vpos.x;
if (vpos.y > max_pos.y)
max_pos.y = vpos.y;
++count;
}
if (count == 0) {
viewer.setCenter(avg_pos);
viewer.setZoom(4);
return;
}
IPM.multiply(avg_pos, 1.0 / (double) count);
viewer.setCenter(avg_pos);
Vec2 range = new Vec2();
IPM.subtractTo(range, max_pos, min_pos);
IPM.add(range, new Vec2(16, 16)); // Margin
Dimension d = viewer.getSize();
double xscale = d.getWidth() / range.x;
double yscale = d.getHeight() / range.y;
double scale = 20;
if (xscale < scale)
scale = xscale;
if (yscale < scale)
scale = yscale;
viewer.setZoom(scale);
}
public void redraw() {
for (CarRenderer cr : carRenderers)
cr.dirty = true;
while(!done_clearing);
viewer.update();
}
}
\ No newline at end of file
......@@ -87,13 +87,18 @@ public class Viewer2D extends JPanel implements MouseInputListener, MouseWheelLi
viewMatrix = computeViewMatrix();
invViewMatrix = computeInvViewMatrix();
}
for (Renderer r : renderers){
//Iterator<Renderer> iter = renderers.iterator();
List<Renderer> copy = new ArrayList<>(renderers);
for (Renderer r : copy){
//while(iter.hasNext()){
if(renderers.contains(r)){
if (dirty || r.dirty){
r.computeGeometry(viewMatrix);
r.dirty = false;
}
r.draw(g2);
}
else{return;}
}
if (dirty) dirty = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment