package nl.andrewl.railsignalapi.service; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import nl.andrewl.railsignalapi.dao.BranchRepository; import nl.andrewl.railsignalapi.dao.RailSystemRepository; import nl.andrewl.railsignalapi.dao.SignalBranchConnectionRepository; import nl.andrewl.railsignalapi.dao.SignalRepository; import nl.andrewl.railsignalapi.model.*; import nl.andrewl.railsignalapi.rest.dto.SignalConnectionsUpdatePayload; import nl.andrewl.railsignalapi.rest.dto.SignalCreationPayload; import nl.andrewl.railsignalapi.rest.dto.SignalResponse; import nl.andrewl.railsignalapi.websocket.BranchUpdateMessage; import nl.andrewl.railsignalapi.websocket.SignalUpdateMessage; import nl.andrewl.railsignalapi.websocket.SignalUpdateType; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.util.*; import java.util.concurrent.ConcurrentHashMap; @Service @RequiredArgsConstructor @Slf4j public class SignalService { private final RailSystemRepository railSystemRepository; private final SignalRepository signalRepository; private final BranchRepository branchRepository; private final SignalBranchConnectionRepository signalBranchConnectionRepository; private final ObjectMapper mapper = new ObjectMapper(); private final Map> signalWebSocketSessions = new ConcurrentHashMap<>(); @Transactional public SignalResponse createSignal(long rsId, SignalCreationPayload payload) { var rs = railSystemRepository.findById(rsId) .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Rail system not found.")); if (signalRepository.existsByNameAndRailSystem(payload.name(), rs)) { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Signal " + payload.name() + " already exists."); } if (payload.branchConnections().size() != 2) { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Exactly two branch connections must be provided."); } // Ensure that the directions of the connections are opposite each other. Direction dir1 = Direction.parse(payload.branchConnections().get(0).direction()); Direction dir2 = Direction.parse(payload.branchConnections().get(1).direction()); if (!dir1.isOpposite(dir2)) throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Branch connections must be opposite each other."); Set branchConnections = new HashSet<>(); Signal signal = new Signal(rs, payload.name(), payload.position(), branchConnections); for (var branchData : payload.branchConnections()) { Branch branch; if (branchData.id() != null) { branch = this.branchRepository.findByIdAndRailSystem(branchData.id(), rs) .orElseThrow(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Branch id " + branchData.id() + " is invalid.")); } else { branch = this.branchRepository.findByNameAndRailSystem(branchData.name(), rs) .orElse(new Branch(rs, branchData.name(), BranchStatus.FREE)); } Direction dir = Direction.parse(branchData.direction()); branchConnections.add(new SignalBranchConnection(signal, branch, dir)); } signal = signalRepository.save(signal); return new SignalResponse(signal); } @Transactional public SignalResponse updateSignalBranchConnections(long rsId, long sigId, SignalConnectionsUpdatePayload payload) { var signal = signalRepository.findByIdAndRailSystemId(sigId, rsId) .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND)); for (var c : payload.connections()) { var fromConnection = signalBranchConnectionRepository.findById(c.from()) .orElseThrow(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Could not find signal branch connection: " + c.from())); if (!fromConnection.getSignal().getId().equals(signal.getId())) { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Can only update signal branch connections originating from the specified signal."); } var toConnection = signalBranchConnectionRepository.findById(c.to()) .orElseThrow(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Could not find signal branch connection: " + c.to())); if (!fromConnection.getBranch().getId().equals(toConnection.getBranch().getId())) { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Signal branch connections can only path via a mutual branch."); } fromConnection.getReachableSignalConnections().add(toConnection); signalBranchConnectionRepository.save(fromConnection); } for (var con : signal.getBranchConnections()) { Set connectionsToRemove = new HashSet<>(); for (var reachableCon : con.getReachableSignalConnections()) { if (!payload.connections().contains(new SignalConnectionsUpdatePayload.ConnectionData(con.getId(), reachableCon.getId()))) { connectionsToRemove.add(reachableCon); } } con.getReachableSignalConnections().removeAll(connectionsToRemove); signalBranchConnectionRepository.save(con); } // Reload the signal. signal = signalRepository.findById(signal.getId()).orElseThrow(); return new SignalResponse(signal); } @Transactional public void registerSignalWebSocketSession(Set signalIds, WebSocketSession session) { this.signalWebSocketSessions.put(session, signalIds); // Instantly send a data packet so that the signals are up-to-date. RailSystem rs = null; for (var signalId : signalIds) { var signal = signalRepository.findById(signalId) .orElseThrow(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid signal id.")); if (rs == null) { rs = signal.getRailSystem(); } else if (!rs.getId().equals(signal.getRailSystem().getId())) { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Cannot open signal websocket session for signals from different rail systems."); } for (var branchConnection : signal.getBranchConnections()) { try { session.sendMessage(new TextMessage(mapper.writeValueAsString( new BranchUpdateMessage( branchConnection.getBranch().getId(), branchConnection.getBranch().getStatus().name() ) ))); } catch (IOException e) { e.printStackTrace(); } } signal.setOnline(true); signalRepository.save(signal); } } @Transactional public void deregisterSignalWebSocketSession(WebSocketSession session) { var ids = this.signalWebSocketSessions.remove(session); if (ids != null) { for (var signalId : ids) { signalRepository.findById(signalId).ifPresent(signal -> { signal.setOnline(false); signalRepository.save(signal); }); } } } public WebSocketSession getSignalWebSocketSession(long signalId) { for (var entry : signalWebSocketSessions.entrySet()) { if (entry.getValue().contains(signalId)) return entry.getKey(); } return null; } @Transactional public void handleSignalUpdate(SignalUpdateMessage updateMessage) { var signal = signalRepository.findById(updateMessage.signalId()) .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND)); Branch fromBranch = null; Branch toBranch = null; for (var con : signal.getBranchConnections()) { if (con.getBranch().getId() == updateMessage.fromBranchId()) { fromBranch = con.getBranch(); } if (con.getBranch().getId() == updateMessage.toBranchId()) { toBranch = con.getBranch(); } } if (fromBranch == null || toBranch == null) { throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid branches."); } SignalUpdateType updateType = SignalUpdateType.valueOf(updateMessage.type().trim().toUpperCase()); if (updateType == SignalUpdateType.BEGIN && toBranch.getStatus() != BranchStatus.FREE) { log.warn("Warning! Train is entering a non-free branch {}.", toBranch.getName()); } if (toBranch.getStatus() != BranchStatus.OCCUPIED) { log.info("Updating branch {} status from {} to {}.", toBranch.getName(), toBranch.getStatus(), BranchStatus.OCCUPIED); toBranch.setStatus(BranchStatus.OCCUPIED); branchRepository.save(toBranch); broadcastToConnectedSignals(toBranch); } if (updateType == SignalUpdateType.END) { if (fromBranch.getStatus() != BranchStatus.FREE) { log.info("Updating branch {} status from {} to {}.", fromBranch.getName(), fromBranch.getStatus(), BranchStatus.FREE); fromBranch.setStatus(BranchStatus.FREE); branchRepository.save(fromBranch); broadcastToConnectedSignals(fromBranch); } } else if (updateType == SignalUpdateType.BEGIN) { if (fromBranch.getStatus() != BranchStatus.OCCUPIED) { log.info("Updating branch {} status from {} to {}.", fromBranch.getName(), fromBranch.getStatus(), BranchStatus.OCCUPIED); fromBranch.setStatus(BranchStatus.OCCUPIED); branchRepository.save(fromBranch); broadcastToConnectedSignals(fromBranch); } } } private void broadcastToConnectedSignals(Branch branch) { try { WebSocketMessage msg = new TextMessage(mapper.writeValueAsString( new BranchUpdateMessage(branch.getId(), branch.getStatus().name()) )); signalRepository.findAllConnectedToBranch(branch).stream() .map(s -> getSignalWebSocketSession(s.getId())) .filter(Objects::nonNull).distinct() .forEach(session -> { try { session.sendMessage(msg); } catch (IOException e) { e.printStackTrace(); } }); } catch (JsonProcessingException e) { e.printStackTrace(); } } @Transactional(readOnly = true) public SignalResponse getSignal(long rsId, long sigId) { var s = signalRepository.findByIdAndRailSystemId(sigId, rsId) .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND)); return new SignalResponse(s); } @Transactional(readOnly = true) public List getAllSignals(long rsId) { var rs = railSystemRepository.findById(rsId) .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Rail system not found.")); return signalRepository.findAllByRailSystemOrderByName(rs).stream() .map(SignalResponse::new) .toList(); } @Transactional public void deleteSignal(long rsId, long sigId) { var s = signalRepository.findByIdAndRailSystemId(sigId, rsId) .orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND)); signalRepository.delete(s); } }