package tcp import ( "bufio" "bytes" "encoding/binary" "encoding/json" "io" "log" "net" "sync" "git.kingecg.top/kingecg/gomog/internal/engine" "git.kingecg.top/kingecg/gomog/pkg/types" ) // 操作码常量 const ( OP_REPLY = 1 OP_UPDATE = 4 OP_INSERT = 8 OP_QUERY = 2004 OP_GETMORE = 2006 OP_DELETE = 2007 OP_MSG = 2013 ) // MessageHeader 消息头(16 字节) type MessageHeader struct { Length uint32 // 消息总长度 RequestID uint32 // 请求 ID ResponseTo uint32 // 响应到的请求 ID OpCode uint32 // 操作码 } // TCPServer TCP 服务器 type TCPServer struct { listener net.Listener handler *MessageHandler wg sync.WaitGroup done chan struct{} } // MessageHandler 消息处理器 type MessageHandler struct { store *engine.MemoryStore crud *engine.CRUDHandler agg *engine.AggregationEngine } // NewMessageHandler 创建消息处理器 func NewMessageHandler(store *engine.MemoryStore, crud *engine.CRUDHandler, agg *engine.AggregationEngine) *MessageHandler { return &MessageHandler{ store: store, crud: crud, agg: agg, } } // NewTCPServer 创建 TCP 服务器 func NewTCPServer(addr string, handler *MessageHandler) (*TCPServer, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, err } return &TCPServer{ listener: ln, handler: handler, done: make(chan struct{}), }, nil } // Start 启动服务器 func (s *TCPServer) Start() error { go s.acceptLoop() return nil } // acceptLoop 接受连接循环 func (s *TCPServer) acceptLoop() { for { select { case <-s.done: return default: conn, err := s.listener.Accept() if err != nil { select { case <-s.done: return default: log.Printf("Accept error: %v", err) continue } } s.wg.Add(1) go s.handleConnection(conn) } } } // handleConnection 处理连接 func (s *TCPServer) handleConnection(conn net.Conn) { defer s.wg.Done() defer conn.Close() reader := bufio.NewReader(conn) for { select { case <-s.done: return default: // 读取消息头 header, err := readHeader(reader) if err != nil { if err != io.EOF { log.Printf("Read header error: %v", err) } return } // 读取消息体 bodySize := header.Length - 16 body := make([]byte, bodySize) if _, err := io.ReadFull(reader, body); err != nil { log.Printf("Read body error: %v", err) return } // 处理消息 response, err := s.handler.HandleMessage(header.OpCode, body, header.RequestID) if err != nil { sendErrorResponse(conn, header.RequestID, err) continue } // 发送响应 if err := writeResponse(conn, header.RequestID, response); err != nil { log.Printf("Write response error: %v", err) return } } } } // readHeader 读取消息头 func readHeader(r *bufio.Reader) (*MessageHeader, error) { header := &MessageHeader{} // 读取 16 字节消息头 buf := make([]byte, 16) if _, err := io.ReadFull(r, buf); err != nil { return nil, err } reader := bytes.NewReader(buf) // 小端序读取 binary.Read(reader, binary.LittleEndian, &header.Length) binary.Read(reader, binary.LittleEndian, &header.RequestID) binary.Read(reader, binary.LittleEndian, &header.ResponseTo) binary.Read(reader, binary.LittleEndian, &header.OpCode) return header, nil } // writeResponse 写入响应 func writeResponse(conn net.Conn, requestID uint32, response interface{}) error { // 序列化响应 data, err := json.Marshal(response) if err != nil { return err } // 构建响应消息 msgLength := uint32(16 + len(data)) header := &MessageHeader{ Length: msgLength, RequestID: 0, // 服务器生成的请求 ID ResponseTo: requestID, OpCode: OP_REPLY, } // 写入消息头 buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, header.Length) binary.Write(buf, binary.LittleEndian, header.RequestID) binary.Write(buf, binary.LittleEndian, header.ResponseTo) binary.Write(buf, binary.LittleEndian, header.OpCode) // 写入消息体 buf.Write(data) _, err = conn.Write(buf.Bytes()) return err } // sendErrorResponse 发送错误响应 func sendErrorResponse(conn net.Conn, requestID uint32, err error) { response := map[string]interface{}{ "ok": 0, "errmsg": err.Error(), } writeResponse(conn, requestID, response) } // HandleMessage 处理消息 func (h *MessageHandler) HandleMessage(opCode uint32, body []byte, requestID uint32) (interface{}, error) { switch opCode { case OP_INSERT: return h.handleInsert(body) case OP_QUERY: return h.handleQuery(body) case OP_UPDATE: return h.handleUpdate(body) case OP_DELETE: return h.handleDelete(body) case OP_MSG: return h.handleMsg(body) default: return nil, ErrUnknownOpCode } } // handleInsert 处理插入消息 func (h *MessageHandler) handleInsert(body []byte) (interface{}, error) { var req struct { Collection string `json:"collection"` Documents []map[string]interface{} `json:"documents"` Ordered bool `json:"ordered"` BypassValidation bool `json:"bypassDocumentValidation"` } if err := json.Unmarshal(body, &req); err != nil { return nil, err } // 执行插入 result, err := h.crud.Insert(nil, req.Collection, req.Documents) if err != nil { return nil, err } return map[string]interface{}{ "ok": 1, "n": result.N, "insertedIds": result.InsertedIDs, }, nil } // handleQuery 处理查询消息 func (h *MessageHandler) handleQuery(body []byte) (interface{}, error) { var req struct { Collection string `json:"collection"` Filter types.Filter `json:"filter"` Projection types.Projection `json:"projection"` Sort types.Sort `json:"sort"` Skip int `json:"skip"` Limit int `json:"limit"` } if err := json.Unmarshal(body, &req); err != nil { return nil, err } // 执行查询 docs, err := h.store.Find(req.Collection, req.Filter) if err != nil { return nil, err } // 应用限制和跳过 if req.Skip > 0 && req.Skip < len(docs) { docs = docs[req.Skip:] } if req.Limit > 0 && req.Limit < len(docs) { docs = docs[:req.Limit] } return map[string]interface{}{ "ok": 1, "cursor": map[string]interface{}{ "firstBatch": docs, "id": 0, "ns": req.Collection, }, }, nil } // handleUpdate 处理更新消息 func (h *MessageHandler) handleUpdate(body []byte) (interface{}, error) { var req struct { Collection string `json:"collection"` Updates []types.UpdateOperation `json:"updates"` Ordered bool `json:"ordered"` } if err := json.Unmarshal(body, &req); err != nil { return nil, err } totalMatched := 0 totalModified := 0 for _, op := range req.Updates { matched, modified, _, err := h.store.Update(req.Collection, op.Q, op.U, op.Upsert, op.ArrayFilters) if err != nil { return nil, err } totalMatched += matched totalModified += modified } return map[string]interface{}{ "ok": 1, "n": totalMatched, "nModified": totalModified, }, nil } // handleDelete 处理删除消息 func (h *MessageHandler) handleDelete(body []byte) (interface{}, error) { var req struct { Collection string `json:"collection"` Deletes []types.DeleteOperation `json:"deletes"` Ordered bool `json:"ordered"` } if err := json.Unmarshal(body, &req); err != nil { return nil, err } totalDeleted := 0 for _, op := range req.Deletes { deleted, err := h.store.Delete(req.Collection, op.Q) if err != nil { return nil, err } totalDeleted += deleted if op.Limit == 1 && deleted > 0 { break } } return map[string]interface{}{ "ok": 1, "n": totalDeleted, "deletedCount": totalDeleted, }, nil } // handleMsg 处理 OP_MSG 消息(MongoDB 3.6+ 通用消息格式) func (h *MessageHandler) handleMsg(body []byte) (interface{}, error) { // 解析 OP_MSG 格式 // 简化实现:假设 body 是 JSON 格式的通用请求 var req struct { Operation string `json:"operation"` Collection string `json:"collection"` Params map[string]interface{} `json:"params"` } if err := json.Unmarshal(body, &req); err != nil { return nil, err } switch req.Operation { case "find": return h.handleFindMsg(req.Collection, req.Params) case "insert": return h.handleInsertMsg(req.Collection, req.Params) case "update": return h.handleUpdateMsg(req.Collection, req.Params) case "delete": return h.handleDeleteMsg(req.Collection, req.Params) case "aggregate": return h.handleAggregateMsg(req.Collection, req.Params) default: return nil, ErrUnknownOperation } } // handleFindMsg 处理 find 消息 func (h *MessageHandler) handleFindMsg(collection string, params map[string]interface{}) (interface{}, error) { filter, _ := params["filter"].(types.Filter) docs, err := h.store.Find(collection, filter) if err != nil { return nil, err } return map[string]interface{}{ "ok": 1, "cursor": map[string]interface{}{ "firstBatch": docs, "id": 0, "ns": collection, }, }, nil } // handleInsertMsg 处理 insert 消息 func (h *MessageHandler) handleInsertMsg(collection string, params map[string]interface{}) (interface{}, error) { documents, ok := params["documents"].([]map[string]interface{}) if !ok { return nil, ErrInvalidDocuments } result, err := h.crud.Insert(nil, collection, documents) if err != nil { return nil, err } return map[string]interface{}{ "ok": 1, "n": result.N, "insertedIds": result.InsertedIDs, }, nil } // handleUpdateMsg 处理 update 消息 func (h *MessageHandler) handleUpdateMsg(collection string, params map[string]interface{}) (interface{}, error) { updatesRaw, ok := params["updates"].([]interface{}) if !ok { return nil, ErrInvalidUpdates } // 转换 updates updates := make([]types.UpdateOperation, 0, len(updatesRaw)) for _, u := range updatesRaw { if updateMap, ok := u.(map[string]interface{}); ok { q, _ := updateMap["q"].(types.Filter) uData, _ := updateMap["u"].(types.Update) upsert, _ := updateMap["upsert"].(bool) multi, _ := updateMap["multi"].(bool) updates = append(updates, types.UpdateOperation{ Q: q, U: uData, Upsert: upsert, Multi: multi, }) } } totalMatched := 0 totalModified := 0 for _, op := range updates { matched, modified, _, err := h.store.Update(collection, op.Q, op.U, op.Upsert, op.ArrayFilters) if err != nil { return nil, err } totalMatched += matched totalModified += modified } return map[string]interface{}{ "ok": 1, "n": totalMatched, "nModified": totalModified, }, nil } // handleDeleteMsg 处理 delete 消息 func (h *MessageHandler) handleDeleteMsg(collection string, params map[string]interface{}) (interface{}, error) { deletesRaw, ok := params["deletes"].([]interface{}) if !ok { return nil, ErrInvalidDeletes } deletes := make([]types.DeleteOperation, 0, len(deletesRaw)) for _, d := range deletesRaw { if deleteMap, ok := d.(map[string]interface{}); ok { q, _ := deleteMap["q"].(types.Filter) limit := 0 if l, ok := deleteMap["limit"].(float64); ok { limit = int(l) } deletes = append(deletes, types.DeleteOperation{ Q: q, Limit: limit, }) } } totalDeleted := 0 for _, op := range deletes { deleted, err := h.store.Delete(collection, op.Q) if err != nil { return nil, err } totalDeleted += deleted if op.Limit == 1 && deleted > 0 { break } } return map[string]interface{}{ "ok": 1, "n": totalDeleted, "deletedCount": totalDeleted, }, nil } // handleAggregateMsg 处理 aggregate 消息 func (h *MessageHandler) handleAggregateMsg(collection string, params map[string]interface{}) (interface{}, error) { pipelineRaw, ok := params["pipeline"].([]interface{}) if !ok { return nil, ErrInvalidPipeline } // 转换 pipeline pipeline := make([]types.AggregateStage, 0, len(pipelineRaw)) for _, stage := range pipelineRaw { if stageMap, ok := stage.(map[string]interface{}); ok { for stageName, spec := range stageMap { pipeline = append(pipeline, types.AggregateStage{ Stage: stageName, Spec: spec, }) break } } } results, err := h.agg.Execute(collection, pipeline) if err != nil { return nil, err } return map[string]interface{}{ "ok": 1, "result": results, }, nil } // Stop 停止服务器 func (s *TCPServer) Stop() error { close(s.done) s.listener.Close() s.wg.Wait() return nil }