369 lines
9.1 KiB
Go
369 lines
9.1 KiB
Go
package engine
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
|
||
"git.kingecg.top/kingecg/gomog/pkg/errors"
|
||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||
)
|
||
|
||
// Special redact markers used by the $redact stage. The result of evaluating
// the redact expression against a (sub)document is compared to these sentinel
// strings to decide how that subtree is handled.
const (
	RedactDescend = "$$DESCEND" // recurse into nested documents/arrays
	RedactPrune   = "$$PRUNE"   // drop the current subtree entirely
	RedactKeep    = "$$KEEP"    // keep the current subtree without descending
)
|
||
|
||
// executeUnionWith 执行 $unionWith 阶段
|
||
func (e *AggregationEngine) executeUnionWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||
var collection string
|
||
var pipelineStages []types.AggregateStage
|
||
|
||
// 解析 spec:支持字符串和对象两种形式
|
||
switch s := spec.(type) {
|
||
case string:
|
||
// 简写形式:{ $unionWith: "collection" }
|
||
collection = s
|
||
pipelineStages = []types.AggregateStage{}
|
||
|
||
case map[string]interface{}:
|
||
// 完整形式:{ $unionWith: { coll: "...", pipeline: [...] } }
|
||
coll, ok := s["coll"].(string)
|
||
if !ok {
|
||
return docs, nil
|
||
}
|
||
collection = coll
|
||
|
||
// 解析 pipeline
|
||
pipelineRaw, _ := s["pipeline"].([]interface{})
|
||
for _, stageRaw := range pipelineRaw {
|
||
stageMap, ok := stageRaw.(map[string]interface{})
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
for stageName, stageSpec := range stageMap {
|
||
pipelineStages = append(pipelineStages, types.AggregateStage{
|
||
Stage: stageName,
|
||
Spec: stageSpec,
|
||
})
|
||
break
|
||
}
|
||
}
|
||
|
||
default:
|
||
return docs, nil
|
||
}
|
||
|
||
// 获取并集集合的所有文档
|
||
unionDocs, err := e.store.GetAllDocuments(collection)
|
||
if err != nil {
|
||
// 集合不存在返回空数组
|
||
unionDocs = []types.Document{}
|
||
}
|
||
|
||
// 如果指定了 pipeline,对并集数据执行 pipeline
|
||
if len(pipelineStages) > 0 {
|
||
unionDocs, err = e.ExecutePipeline(unionDocs, pipelineStages)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, errors.ErrAggregationError, "failed to execute union pipeline")
|
||
}
|
||
}
|
||
|
||
// 合并原文档和并集文档
|
||
result := make([]types.Document, 0, len(docs)+len(unionDocs))
|
||
result = append(result, docs...)
|
||
result = append(result, unionDocs...)
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// executeRedact 执行 $redact 阶段
|
||
func (e *AggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||
var results []types.Document
|
||
|
||
for _, doc := range docs {
|
||
redactedData, keep := e.redactDocument(doc.Data, spec)
|
||
|
||
if keep {
|
||
results = append(results, types.Document{
|
||
ID: doc.ID,
|
||
Data: redactedData.(map[string]interface{}),
|
||
})
|
||
}
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// redactDocument 递归处理文档的红黑
|
||
func (e *AggregationEngine) redactDocument(data interface{}, spec interface{}) (interface{}, bool) {
|
||
// 评估红黑表达式
|
||
dataMap, ok := data.(map[string]interface{})
|
||
if !ok {
|
||
return data, true
|
||
}
|
||
|
||
result := e.evaluateExpression(dataMap, spec)
|
||
|
||
// 根据结果决定行为
|
||
switch result {
|
||
case RedactKeep:
|
||
return data, true
|
||
case RedactPrune:
|
||
return nil, false
|
||
case RedactDescend:
|
||
// 继续处理嵌套结构
|
||
return e.redactNested(data, spec)
|
||
default:
|
||
// 默认继续 descend
|
||
return e.redactNested(data, spec)
|
||
}
|
||
}
|
||
|
||
// redactNested 递归处理嵌套文档和数组
|
||
func (e *AggregationEngine) redactNested(data interface{}, spec interface{}) (interface{}, bool) {
|
||
switch d := data.(type) {
|
||
case map[string]interface{}:
|
||
return e.redactMap(d, spec)
|
||
case []interface{}:
|
||
return e.redactArray(d, spec)
|
||
default:
|
||
return data, true
|
||
}
|
||
}
|
||
|
||
func (e *AggregationEngine) redactMap(m map[string]interface{}, spec interface{}) (map[string]interface{}, bool) {
|
||
result := make(map[string]interface{})
|
||
|
||
for k, v := range m {
|
||
fieldResult, keep := e.redactDocument(v, spec)
|
||
|
||
if keep {
|
||
result[k] = fieldResult
|
||
}
|
||
}
|
||
|
||
return result, true
|
||
}
|
||
|
||
func (e *AggregationEngine) redactArray(arr []interface{}, spec interface{}) ([]interface{}, bool) {
|
||
result := make([]interface{}, 0)
|
||
|
||
for _, item := range arr {
|
||
itemResult, keep := e.redactDocument(item, spec)
|
||
if keep {
|
||
result = append(result, itemResult)
|
||
}
|
||
}
|
||
|
||
return result, true
|
||
}
|
||
|
||
// executeOut 执行 $out 阶段
|
||
func (e *AggregationEngine) executeOut(spec interface{}, docs []types.Document, currentCollection string) ([]types.Document, error) {
|
||
var targetCollection string
|
||
|
||
// 解析 spec:支持字符串和对象两种形式
|
||
switch s := spec.(type) {
|
||
case string:
|
||
targetCollection = s
|
||
|
||
case map[string]interface{}:
|
||
// 支持 { db: "...", coll: "..." } 形式
|
||
if db, ok := s["db"].(string); ok && db != "" {
|
||
targetCollection = db + "." + s["coll"].(string)
|
||
} else {
|
||
targetCollection = s["coll"].(string)
|
||
}
|
||
|
||
default:
|
||
return nil, errors.New(errors.ErrInvalidRequest, "invalid $out specification")
|
||
}
|
||
|
||
// 删除目标集合的现有数据(如果有)
|
||
err := e.store.DropCollection(targetCollection)
|
||
if err != nil && err != errors.ErrCollectionNotFnd {
|
||
return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to drop target collection")
|
||
}
|
||
|
||
// 创建新集合并插入所有文档
|
||
for _, doc := range docs {
|
||
err := e.store.InsertDocument(targetCollection, doc)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, errors.ErrDatabaseError, "failed to insert document")
|
||
}
|
||
}
|
||
|
||
// 返回确认文档
|
||
return []types.Document{{
|
||
Data: map[string]interface{}{
|
||
"ok": float64(1),
|
||
"nInserted": float64(len(docs)),
|
||
"targetCollection": targetCollection,
|
||
},
|
||
}}, nil
|
||
}
|
||
|
||
// executeMerge 执行 $merge 阶段
|
||
func (e *AggregationEngine) executeMerge(spec interface{}, docs []types.Document, currentCollection string) ([]types.Document, error) {
|
||
// 解析 spec
|
||
mergeSpec, ok := spec.(map[string]interface{})
|
||
if !ok {
|
||
return nil, errors.New(errors.ErrInvalidRequest, "invalid $merge specification")
|
||
}
|
||
|
||
// 获取目标集合名
|
||
var targetCollection string
|
||
switch into := mergeSpec["into"].(type) {
|
||
case string:
|
||
targetCollection = into
|
||
case map[string]interface{}:
|
||
targetCollection = into["coll"].(string)
|
||
default:
|
||
return nil, errors.New(errors.ErrInvalidRequest, "invalid $merge into specification")
|
||
}
|
||
|
||
// 获取匹配字段(默认 _id)
|
||
onField, _ := mergeSpec["on"].(string)
|
||
if onField == "" {
|
||
onField = "_id"
|
||
}
|
||
|
||
// 获取匹配策略
|
||
whenMatched, _ := mergeSpec["whenMatched"].(string)
|
||
if whenMatched == "" {
|
||
whenMatched = "replace"
|
||
}
|
||
|
||
whenNotMatched, _ := mergeSpec["whenNotMatched"].(string)
|
||
if whenNotMatched == "" {
|
||
whenNotMatched = "insert"
|
||
}
|
||
|
||
// 获取目标集合现有文档
|
||
existingDocs, _ := e.store.GetAllDocuments(targetCollection)
|
||
existingMap := make(map[string]types.Document)
|
||
for _, doc := range existingDocs {
|
||
key := getDocumentKey(doc, onField)
|
||
existingMap[key] = doc
|
||
}
|
||
|
||
// 统计信息
|
||
stats := map[string]float64{
|
||
"nInserted": 0,
|
||
"nUpdated": 0,
|
||
"nUnchanged": 0,
|
||
"nDeleted": 0,
|
||
}
|
||
|
||
// 处理每个输入文档
|
||
for _, doc := range docs {
|
||
key := getDocumentKey(doc, onField)
|
||
_, exists := existingMap[key]
|
||
|
||
if exists {
|
||
// 文档已存在
|
||
switch whenMatched {
|
||
case "replace":
|
||
e.store.UpdateDocument(targetCollection, doc)
|
||
stats["nUpdated"]++
|
||
|
||
case "keepExisting":
|
||
stats["nUnchanged"]++
|
||
|
||
case "merge":
|
||
// 合并字段
|
||
if existing, ok := existingMap[key]; ok {
|
||
mergedData := deepCopyMap(existing.Data)
|
||
for k, v := range doc.Data {
|
||
mergedData[k] = v
|
||
}
|
||
doc.Data = mergedData
|
||
e.store.UpdateDocument(targetCollection, doc)
|
||
stats["nUpdated"]++
|
||
}
|
||
|
||
case "fail":
|
||
return nil, errors.New(errors.ErrDuplicateKey, "document already exists")
|
||
|
||
case "delete":
|
||
// 删除已存在的文档
|
||
stats["nDeleted"]++
|
||
}
|
||
} else {
|
||
// 文档不存在
|
||
if whenNotMatched == "insert" {
|
||
e.store.InsertDocument(targetCollection, doc)
|
||
stats["nInserted"]++
|
||
}
|
||
}
|
||
}
|
||
|
||
// 返回统计信息
|
||
return []types.Document{{
|
||
Data: map[string]interface{}{
|
||
"ok": float64(1),
|
||
"nInserted": stats["nInserted"],
|
||
"nUpdated": stats["nUpdated"],
|
||
"nUnchanged": stats["nUnchanged"],
|
||
"nDeleted": stats["nDeleted"],
|
||
},
|
||
}}, nil
|
||
}
|
||
|
||
// getDocumentKey 获取文档的唯一键
|
||
func getDocumentKey(doc types.Document, keyField string) string {
|
||
if keyField == "_id" {
|
||
return doc.ID
|
||
}
|
||
|
||
value := getNestedValue(doc.Data, keyField)
|
||
if value == nil {
|
||
return ""
|
||
}
|
||
|
||
return fmt.Sprintf("%v", value)
|
||
}
|
||
|
||
// executeIndexStats 执行 $indexStats 阶段(简化版本)
|
||
func (e *AggregationEngine) executeIndexStats(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||
// 返回模拟的索引统计信息
|
||
return []types.Document{{
|
||
Data: map[string]interface{}{
|
||
"name": "id_idx",
|
||
"key": map[string]interface{}{"_id": 1},
|
||
"accesses": map[string]interface{}{
|
||
"ops": float64(0),
|
||
"since": "2024-01-01T00:00:00Z",
|
||
},
|
||
},
|
||
}}, nil
|
||
}
|
||
|
||
// executeCollStats 执行 $collStats 阶段(简化版本)
|
||
func (e *AggregationEngine) executeCollStats(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||
// 返回集合统计信息
|
||
return []types.Document{{
|
||
Data: map[string]interface{}{
|
||
"ns": "test.collection",
|
||
"count": float64(len(docs)),
|
||
"size": estimateSize(docs),
|
||
"storageSize": float64(0), // 内存存储无此概念
|
||
"nindexes": float64(1),
|
||
},
|
||
}}, nil
|
||
}
|
||
|
||
// estimateSize 估算文档大小(字节)
|
||
func estimateSize(docs []types.Document) float64 {
|
||
total := 0
|
||
for _, doc := range docs {
|
||
// JSON 序列化后的大小
|
||
data, _ := json.Marshal(doc.Data)
|
||
total += len(data)
|
||
}
|
||
return float64(total)
|
||
}
|