diff --git a/backend/onetime/seed-data/main.go b/backend/onetime/seed-data/main.go index f93eef40..dba42c82 100644 --- a/backend/onetime/seed-data/main.go +++ b/backend/onetime/seed-data/main.go @@ -29,15 +29,18 @@ func main() { if err := godotenv.Load(".env"); err != nil { panic(err) } - db, err := dbhandler.Open(context.TODO(), options.Client().ApplyURI(os.Getenv("MONGODB_URI"))) + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(nil) + db, err := dbhandler.Open(ctx, options.Client().ApplyURI(os.Getenv("MONGODB_URI"))) if err != nil { + cancel(err) return } - defer db.Close() + defer db.Close(ctx) data := &Seed{} b, _ := os.ReadFile("./data/nt-tokyo.yaml") if err := yaml.Unmarshal(b, data); err != nil { - panic(err) + cancel(err) } //for _, stop := range data.StopRails { @@ -63,11 +66,12 @@ func main() { //} for _, block := range data.Blocks { println(block) - err := db.AddBlock(&statev1.BlockState{ + err := db.AddBlock(ctx, &statev1.BlockState{ BlockId: string(block), State: statev1.BlockStateEnum_BLOCK_STATE_OPEN, }) if err != nil { + cancel(err) return } } diff --git a/backend/state-manager/pkg/connect/connect_handler.go b/backend/state-manager/pkg/connect/connect_handler.go index cfc00599..63e5ca26 100644 --- a/backend/state-manager/pkg/connect/connect_handler.go +++ b/backend/state-manager/pkg/connect/connect_handler.go @@ -27,7 +27,7 @@ func (s *StateManagerServer) GetBlockStates( ctx context.Context, req *connect.Request[statev1.GetBlockStatesRequest], ) (*connect.Response[statev1.GetBlockStatesResponse], error) { - blockStates, err := s.DBHandler.GetBlocks() + blockStates, err := s.DBHandler.GetBlocks(ctx) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -57,7 +57,7 @@ func (s *StateManagerServer) UpdateBlockState( ctx context.Context, req *connect.Request[statev1.UpdateBlockStateRequest], ) (*connect.Response[statev1.UpdateBlockStateResponse], error) { - err := s.DBHandler.UpdateBlock(req.Msg.State) + err := s.DBHandler.UpdateBlock(ctx, req.Msg.State) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -77,7 +77,7 @@ func (s *StateManagerServer) UpdatePointState( ctx context.Context, req *connect.Request[statev1.UpdatePointStateRequest], ) (*connect.Response[statev1.UpdatePointStateResponse], error) { - err := s.DBHandler.UpdatePoint(req.Msg.State) + err := s.DBHandler.UpdatePoint(ctx, req.Msg.State) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -86,7 +86,7 @@ func (s *StateManagerServer) UpdatePointState( slog.Default().Error("db error", err) return nil, err } - s.MqttHandler.NotifyStateUpdate("point", req.Msg.State.Id, req.Msg.State.State.String()) + s.MqttHandler.NotifyStateUpdate(ctx, "point", req.Msg.State.Id, req.Msg.State.State.String()) return connect.NewResponse(&statev1.UpdatePointStateResponse{}), nil } @@ -110,7 +110,7 @@ func (s *StateManagerServer) UpdateStopState( ctx context.Context, req *connect.Request[statev1.UpdateStopStateRequest], ) (*connect.Response[statev1.UpdateStopStateResponse], error) { - err := s.DBHandler.UpdateStop(req.Msg.State) + err := s.DBHandler.UpdateStop(ctx,req.Msg.State) if err != nil { err = connect.NewError( connect.CodeUnknown, @@ -119,7 +119,7 @@ func (s *StateManagerServer) UpdateStopState( slog.Default().Error("db connection error", err) return nil, err } - s.MqttHandler.NotifyStateUpdate("stop", req.Msg.State.Id, req.Msg.State.State.String()) + s.MqttHandler.NotifyStateUpdate(ctx, "stop", req.Msg.State.Id, req.Msg.State.State.String()) return connect.NewResponse(&statev1.UpdateStopStateResponse{}), nil } diff --git a/backend/state-manager/pkg/db/db.go b/backend/state-manager/pkg/db/db.go index 80903aab..4276956f 100644 --- a/backend/state-manager/pkg/db/db.go +++ b/backend/state-manager/pkg/db/db.go @@ -7,7 +7,6 @@ package db import ( "context" "fmt" - "log" "log/slog" statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1" @@ -40,12 +39,11 @@ func Open(ctx context.Context, opts *options.ClientOptions) (*DBHandler, error) }, nil } -func (db *DBHandler) Close() { +func (db *DBHandler) Close(ctx context.Context) { slog.Default().Debug("Closing connection to DB...") - // TODO: contextを受けて、その子contextをDBクライアントに渡す - if err := db.stateManagerDB.Client().Disconnect(context.TODO()); err != nil { + if err := db.stateManagerDB.Client().Disconnect(ctx); err != nil { slog.Default().Error("DB Connection Closing failed") - log.Println(err) + return } slog.Default().Debug("DB Connection is successfully closed") } @@ -54,148 +52,145 @@ func (db *DBHandler) Close() { Point */ -func (db *DBHandler) UpdatePoint(PointAndState *statev1.PointAndState) error { +func (db *DBHandler) UpdatePoint(ctx context.Context, PointAndState *statev1.PointAndState) error { collection := db.stateManagerDB.Collection("points") _, err := collection.UpdateOne( - context.Background(), + ctx, bson.M{"id": PointAndState.Id}, bson.M{"$set": bson.M{"state": PointAndState.State}}, ) - if err != nil { - return err - } - return nil + return fmt.Errorf("update point failed `%w`", err) } -func (db *DBHandler) AddPoint(PointAndState *statev1.PointAndState) error { +func (db *DBHandler) AddPoint(ctx context.Context, PointAndState *statev1.PointAndState) error { collection := db.stateManagerDB.Collection("points") - _, err := collection.InsertOne(context.Background(), PointAndState) + _, err := collection.InsertOne(ctx, PointAndState) if err != nil { - return err + return fmt.Errorf("insert point failed `%w`", err) } return nil } -func (db *DBHandler) GetPoint(pointId string) (*statev1.PointAndState, error) { +func (db *DBHandler) GetPoint(ctx context.Context, pointId string) (*statev1.PointAndState, error) { collection := db.stateManagerDB.Collection("points") var result *statev1.PointAndState - err := collection.FindOne(context.Background(), bson.M{"id": pointId}).Decode(&result) + err := collection.FindOne(ctx, bson.M{"id": pointId}).Decode(&result) if err != nil { - return nil, err + return nil, fmt.Errorf("get point failed `%w`", err) } return result, nil } -func (db *DBHandler) GetPoints() []*statev1.PointAndState { +func (db *DBHandler) GetPoints(ctx context.Context) ([]*statev1.PointAndState, error) { collection := db.stateManagerDB.Collection("points") - cursor, err := collection.Find(context.Background(), bson.M{}) + cursor, err := collection.Find(ctx, bson.M{}) if err != nil { slog.Default().Warn("Get Points failed", slog.Any("err", err)) - panic(err) + return nil, fmt.Errorf("get points failed `%w`", err) } var result []*statev1.PointAndState - if err = cursor.All(context.Background(), &result); err != nil { - panic(err) + if err := cursor.All(ctx, &result); err != nil { + slog.Default().Warn("Get Points failed", slog.Any("err", err)) + return nil, fmt.Errorf("get points failed `%w`", err) } - return result + return result, nil } /* Stop */ -func (db *DBHandler) UpdateStop(stop *statev1.StopAndState) error { +func (db *DBHandler) UpdateStop(ctx context.Context, stop *statev1.StopAndState) error { collection := db.stateManagerDB.Collection("stops") - _, err := collection.UpdateOne( - context.Background(), + ctx, bson.M{"id": stop.Id}, bson.M{"$set": bson.M{"state": stop.State}}, ) if err != nil { - return err + return fmt.Errorf("update stop failed `%w`", err) } return nil } -func (db *DBHandler) AddStop(stop *statev1.StopAndState) error { +func (db *DBHandler) AddStop(ctx context.Context, stop *statev1.StopAndState) error { collection := db.stateManagerDB.Collection("stops") - _, err := collection.InsertOne(context.Background(), stop) + _, err := collection.InsertOne(ctx, stop) if err != nil { - return err + return fmt.Errorf("insert stop failed `%w`", err) } return nil } -func (db *DBHandler) GetStop(stopId string) (*statev1.StopAndState, error) { +func (db *DBHandler) GetStop(ctx context.Context, stopId string) (*statev1.StopAndState, error) { collection := db.stateManagerDB.Collection("stops") var result *statev1.StopAndState - err := collection.FindOne(context.Background(), bson.M{"id": stopId}).Decode(&result) + err := collection.FindOne(ctx, bson.M{"id": stopId}).Decode(&result) if err != nil { - return nil, err + return nil, fmt.Errorf("get stop failed `%w`", err) } return result, nil } -func (db *DBHandler) GetStops() []*statev1.StopAndState { +func (db *DBHandler) GetStops(ctx context.Context) ([]*statev1.StopAndState, error) { collection := db.stateManagerDB.Collection("stops") - cursor, err := collection.Find(context.Background(), bson.M{}) + cursor, err := collection.Find(ctx, bson.M{}) if err != nil { - panic(err) + return nil, fmt.Errorf("get stops failed `%w`", err) } var result []*statev1.StopAndState - if err = cursor.All(context.Background(), &result); err != nil { - panic(err) + if err := cursor.All(ctx, &result); err != nil { + return nil, fmt.Errorf("get stops failed `%w`", err) } - return result + return result, nil } /* Block */ -func (db *DBHandler) AddBlock(block *statev1.BlockState) error { +func (db *DBHandler) AddBlock(ctx context.Context, block *statev1.BlockState) error { collection := db.stateManagerDB.Collection("blocks") - _, err := collection.InsertOne(context.Background(), block) + _, err := collection.InsertOne(ctx, block) if err != nil { - return err + return fmt.Errorf("insert block failed `%w`", err) } return nil } -func (db *DBHandler) UpdateBlock(block *statev1.BlockState) error { +func (db *DBHandler) UpdateBlock(ctx context.Context, block *statev1.BlockState) error { collection := db.stateManagerDB.Collection("blocks") _, err := collection.UpdateOne( - context.Background(), + ctx, bson.M{"blockid": block.BlockId}, bson.M{"$set": bson.M{"state": block.State}}, ) if err != nil { - return err + return fmt.Errorf("update block failed `%w`", err) } return nil } -func (db *DBHandler) GetBlock(blockId string) (*statev1.BlockState, error) { +func (db *DBHandler) GetBlock(ctx context.Context, blockId string) (*statev1.BlockState, error) { collection := db.stateManagerDB.Collection("blocks") var result *statev1.BlockState - err := collection.FindOne(context.Background(), bson.M{"blockid": blockId}).Decode(&result) + err := collection.FindOne(ctx, bson.M{"blockid": blockId}).Decode(&result) if err != nil { - return nil, err + return nil, fmt.Errorf("get block failed `%w`", err) } return result, nil } -func (db *DBHandler) GetBlocks() ([]*statev1.BlockState, error) { +func (db *DBHandler) GetBlocks(ctx context.Context) ([]*statev1.BlockState, error) { collection := db.stateManagerDB.Collection("blocks") - cursor, err := collection.Find(context.Background(), bson.M{}) + cursor, err := collection.Find(ctx, bson.M{}) if err != nil { - return nil, err + return nil, fmt.Errorf("get blocks failed `%w`", err) } var result []*statev1.BlockState - if err = cursor.All(context.Background(), &result); err != nil { - return nil, err + if err = cursor.All(ctx, &result); err != nil { + return nil, fmt.Errorf("get blocks failed `%w`", err) } return result, nil } diff --git a/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go b/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go index bccd6bff..2ae89e97 100644 --- a/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go +++ b/backend/state-manager/pkg/mqtt_handler/mqtt_handler.go @@ -41,7 +41,7 @@ func (h *Handler) Start(ctx context.Context) error { case msg := <-msgCh: // if topic start with "point/" log.Printf("Received message: %s from topic: %s\n", msg.Payload(), msg.Topic()) - h.topicHandler(msg) + h.topicHandler(ctx, msg) case <-ctx.Done(): slog.Default().Info("Interrupted at mqtt_handler") h.client.Disconnect(1000) @@ -77,7 +77,7 @@ func (h *Handler) Send(topic string, payload string) { {target}/{pointId}/update */ -func (h *Handler) topicHandler(msg mqtt.Message) { +func (h *Handler) topicHandler(ctx context.Context, msg mqtt.Message) { // Handle by Path arr := strings.Split(msg.Topic(), "/") target := arr[0] @@ -90,32 +90,37 @@ func (h *Handler) topicHandler(msg mqtt.Message) { switch method { case "get": - h.getState(target, id) + h.getState(ctx, target, id) case "delta": - h.getDelta(target, id) + h.getDelta(ctx, target, id) case "update": - h.updateState(target, id, msg.Payload()) + h.updateState(ctx, target, id, msg.Payload()) } } -func (h *Handler) NotifyStateUpdate(target string, id string, state string) { +func (h *Handler) NotifyStateUpdate(ctx context.Context, target string, id string, state string) { token := h.client.Publish(target+"/"+id+"/delta", 0, false, state) - token.Wait() + select { + case <-token.Done(): + slog.Default().Info("token done in mqtt_handler.NotifyStateUpdate", slog.Any("err", token.Error())) + case <-ctx.Done(): + slog.Default().Info("context done in mqtt_handler.NotifyStateUpdate", slog.Any("err", ctx.Err())) + } } -func (h *Handler) getState(target string, id string) { +func (h *Handler) getState(ctx context.Context, target string, id string) error { switch target { case "point": - point, err := h.dbHandler.GetPoint(id) + point, err := h.dbHandler.GetPoint(ctx, id) if err != nil { - log.Fatal(err) + slog.Default().Info("db error in mqtt_handler.getState", slog.Any("err", err)) } log.Println(point) token := h.client.Publish("point/"+id+"/get/accepted", 0, false, point.State.String()) token.Wait() case "stop": - stop, err := h.dbHandler.GetStop(id) + stop, err := h.dbHandler.GetStop(ctx, id) if err != nil { log.Fatal(err) } @@ -124,7 +129,7 @@ func (h *Handler) getState(target string, id string) { token.Wait() case "block": - block, err := h.dbHandler.GetBlock(id) + block, err := h.dbHandler.GetBlock(ctx, id) if err != nil { log.Fatal(err) } @@ -160,13 +165,14 @@ func (h *Handler) getState(target string, id string) { case "train": // TODO: implement } + return nil } -func (h *Handler) getDelta(target string, id string) { +func (h *Handler) getDelta(ctx context.Context, target string, id string) { } -func (h *Handler) updateState(target string, id string, payload []byte) { +func (h *Handler) updateState(ctx context.Context, target string, id string, payload []byte) { switch target { case "block": @@ -175,7 +181,7 @@ func (h *Handler) updateState(target string, id string, payload []byte) { fmt.Print("newState: ") fmt.Println(newState) if newState == "OPEN" { - err := h.dbHandler.UpdateBlock(&statev1.BlockState{ + err := h.dbHandler.UpdateBlock(ctx, &statev1.BlockState{ BlockId: id, State: statev1.BlockStateEnum_BLOCK_STATE_OPEN, }) @@ -184,25 +190,25 @@ func (h *Handler) updateState(target string, id string, payload []byte) { } // NT Tokyo if id == "yamashita_b1" { - err := h.dbHandler.UpdateStop(&statev1.StopAndState{ + err := h.dbHandler.UpdateStop(ctx, &statev1.StopAndState{ Id: "yamashita_s1", State: statev1.StopStateEnum_STOP_STATE_GO, }) if err != nil { log.Fatal(err) } - h.NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_GO.String()) - err = h.dbHandler.UpdateStop(&statev1.StopAndState{ + h.NotifyStateUpdate(ctx, "stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_GO.String()) + err = h.dbHandler.UpdateStop(ctx, &statev1.StopAndState{ Id: "yamashita_s2", State: statev1.StopStateEnum_STOP_STATE_GO, }) if err != nil { log.Fatal(err) } - h.NotifyStateUpdate("stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_GO.String()) + h.NotifyStateUpdate(ctx, "stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_GO.String()) // 今と逆にする - now, err := h.dbHandler.GetPoint("yamashita_p1") + now, err := h.dbHandler.GetPoint(ctx, "yamashita_p1") if err != nil { log.Fatal(err) } @@ -212,7 +218,7 @@ func (h *Handler) updateState(target string, id string, payload []byte) { } else { newS = statev1.PointStateEnum_POINT_STATE_NORMAL } - err = h.dbHandler.UpdatePoint(&statev1.PointAndState{ + err = h.dbHandler.UpdatePoint(ctx, &statev1.PointAndState{ Id: "yamashita_p1", State: newS, }) @@ -221,11 +227,11 @@ func (h *Handler) updateState(target string, id string, payload []byte) { log.Fatal(err) } - h.NotifyStateUpdate("point", "yamashita_p1", newS.String()) + h.NotifyStateUpdate(ctx, "point", "yamashita_p1", newS.String()) } } else if newState == "CLOSE" { - err := h.dbHandler.UpdateBlock(&statev1.BlockState{ + err := h.dbHandler.UpdateBlock(ctx, &statev1.BlockState{ BlockId: id, State: statev1.BlockStateEnum_BLOCK_STATE_CLOSE, }) @@ -234,22 +240,22 @@ func (h *Handler) updateState(target string, id string, payload []byte) { } // NT Tokyo if id == "yamashita_b1" { - err := h.dbHandler.UpdateStop(&statev1.StopAndState{ + err := h.dbHandler.UpdateStop(ctx, &statev1.StopAndState{ Id: "yamashita_s1", State: statev1.StopStateEnum_STOP_STATE_STOP, }) if err != nil { log.Fatal(err) } - h.NotifyStateUpdate("stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_STOP.String()) - err = h.dbHandler.UpdateStop(&statev1.StopAndState{ + h.NotifyStateUpdate(ctx, "stop", "yamashita_s1", statev1.StopStateEnum_STOP_STATE_STOP.String()) + err = h.dbHandler.UpdateStop(ctx, &statev1.StopAndState{ Id: "yamashita_s2", State: statev1.StopStateEnum_STOP_STATE_STOP, }) if err != nil { log.Fatal(err) } - h.NotifyStateUpdate("stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_STOP.String()) + h.NotifyStateUpdate(ctx, "stop", "yamashita_s2", statev1.StopStateEnum_STOP_STATE_STOP.String()) } }