545 lines
12 KiB
Go
545 lines
12 KiB
Go
package engine
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.kingecg.top/kingecg/gomog/pkg/types"
|
|
)
|
|
|
|
// executeReplaceRoot 执行 $replaceRoot 阶段
|
|
func (e *AggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
|
specMap, ok := spec.(map[string]interface{})
|
|
if !ok {
|
|
return docs, nil
|
|
}
|
|
|
|
newRootRaw, exists := specMap["newRoot"]
|
|
if !exists {
|
|
return docs, nil
|
|
}
|
|
|
|
var results []types.Document
|
|
for _, doc := range docs {
|
|
newRoot := e.evaluateExpression(doc.Data, newRootRaw)
|
|
if newRootMap, ok := newRoot.(map[string]interface{}); ok {
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: newRootMap,
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
} else {
|
|
// 如果不是对象,创建包装文档
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: map[string]interface{}{"value": newRoot},
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
}
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// executeReplaceWith 执行 $replaceWith 阶段($replaceRoot 的别名)
|
|
func (e *AggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
|
// $replaceWith 是 $replaceRoot 的简写形式
|
|
// spec 本身就是 newRoot 表达式
|
|
var results []types.Document
|
|
for _, doc := range docs {
|
|
newRoot := e.evaluateExpression(doc.Data, spec)
|
|
if newRootMap, ok := newRoot.(map[string]interface{}); ok {
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: newRootMap,
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
} else {
|
|
// 如果不是对象,创建包装文档
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: map[string]interface{}{"value": newRoot},
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
}
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// executeGraphLookup 执行 $graphLookup 阶段(递归查找)
|
|
func (e *AggregationEngine) executeGraphLookup(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
|
specMap, ok := spec.(map[string]interface{})
|
|
if !ok {
|
|
return docs, nil
|
|
}
|
|
|
|
from, _ := specMap["from"].(string)
|
|
startWith := specMap["startWith"]
|
|
connectFromField, _ := specMap["connectFromField"].(string)
|
|
connectToField, _ := specMap["connectToField"].(string)
|
|
as, _ := specMap["as"].(string)
|
|
maxDepthRaw, _ := specMap["maxDepth"].(float64)
|
|
restrictSearchWithMatchRaw, _ := specMap["restrictSearchWithMatch"]
|
|
|
|
if as == "" || connectFromField == "" || connectToField == "" {
|
|
return docs, nil
|
|
}
|
|
|
|
maxDepth := int(maxDepthRaw)
|
|
if maxDepth == 0 {
|
|
maxDepth = -1 // 无限制
|
|
}
|
|
|
|
var results []types.Document
|
|
for _, doc := range docs {
|
|
// 计算起始值
|
|
startValue := e.evaluateExpression(doc.Data, startWith)
|
|
|
|
// 递归查找
|
|
connectedDocs := e.graphLookupRecursive(
|
|
from,
|
|
startValue,
|
|
connectFromField,
|
|
connectToField,
|
|
maxDepth,
|
|
restrictSearchWithMatchRaw,
|
|
make(map[string]bool),
|
|
)
|
|
|
|
// 添加结果数组
|
|
newDoc := make(map[string]interface{})
|
|
for k, v := range doc.Data {
|
|
newDoc[k] = v
|
|
}
|
|
newDoc[as] = connectedDocs
|
|
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: newDoc,
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// graphLookupRecursive 递归查找关联文档
|
|
func (e *AggregationEngine) graphLookupRecursive(
|
|
collection string,
|
|
startValue interface{},
|
|
connectFromField string,
|
|
connectToField string,
|
|
maxDepth int,
|
|
restrictSearchWithMatch interface{},
|
|
visited map[string]bool,
|
|
) []map[string]interface{} {
|
|
|
|
var results []map[string]interface{}
|
|
|
|
if maxDepth == 0 {
|
|
return results
|
|
}
|
|
|
|
// 获取目标集合
|
|
targetCollection := e.store.collections[collection]
|
|
if targetCollection == nil {
|
|
return results
|
|
}
|
|
|
|
// 查找匹配的文档
|
|
for docID, doc := range targetCollection.documents {
|
|
// 避免循环引用
|
|
if visited[docID] {
|
|
continue
|
|
}
|
|
|
|
// 检查是否匹配
|
|
docValue := getNestedValue(doc.Data, connectToField)
|
|
if !valuesEqual(startValue, docValue) {
|
|
continue
|
|
}
|
|
|
|
// 应用 restrictSearchWithMatch 过滤
|
|
if restrictSearchWithMatch != nil {
|
|
if matchSpec, ok := restrictSearchWithMatch.(map[string]interface{}); ok {
|
|
if !MatchFilter(doc.Data, matchSpec) {
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// 标记为已访问
|
|
visited[docID] = true
|
|
|
|
// 添加到结果
|
|
docCopy := make(map[string]interface{})
|
|
for k, v := range doc.Data {
|
|
docCopy[k] = v
|
|
}
|
|
results = append(results, docCopy)
|
|
|
|
// 递归查找下一级
|
|
nextValue := getNestedValue(doc.Data, connectFromField)
|
|
moreResults := e.graphLookupRecursive(
|
|
collection,
|
|
nextValue,
|
|
connectFromField,
|
|
connectToField,
|
|
maxDepth-1,
|
|
restrictSearchWithMatch,
|
|
visited,
|
|
)
|
|
results = append(results, moreResults...)
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
// executeSetWindowFields 执行 $setWindowFields 阶段(窗口函数)
|
|
func (e *AggregationEngine) executeSetWindowFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
|
specMap, ok := spec.(map[string]interface{})
|
|
if !ok {
|
|
return docs, nil
|
|
}
|
|
|
|
outputsRaw, _ := specMap["output"].(map[string]interface{})
|
|
partitionByRaw, _ := specMap["partitionBy"]
|
|
sortByRaw, _ := specMap["sortBy"].(map[string]interface{})
|
|
|
|
if outputsRaw == nil {
|
|
return docs, nil
|
|
}
|
|
|
|
// 分组(分区)
|
|
partitions := make(map[string][]types.Document)
|
|
for _, doc := range docs {
|
|
var key string
|
|
if partitionByRaw != nil {
|
|
partitionKey := e.evaluateExpression(doc.Data, partitionByRaw)
|
|
key = fmt.Sprintf("%v", partitionKey)
|
|
} else {
|
|
key = "all"
|
|
}
|
|
partitions[key] = append(partitions[key], doc)
|
|
}
|
|
|
|
// 对每个分区排序
|
|
for key := range partitions {
|
|
if sortByRaw != nil && len(sortByRaw) > 0 {
|
|
sortDocsBySpec(partitions[key], sortByRaw)
|
|
}
|
|
}
|
|
|
|
// 应用窗口函数
|
|
var results []types.Document
|
|
for _, partition := range partitions {
|
|
for i, doc := range partition {
|
|
newDoc := make(map[string]interface{})
|
|
for k, v := range doc.Data {
|
|
newDoc[k] = v
|
|
}
|
|
|
|
// 计算每个输出字段
|
|
for fieldName, windowSpecRaw := range outputsRaw {
|
|
windowSpec, ok := windowSpecRaw.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
value := e.calculateWindowValue(windowSpec, partition, i, doc)
|
|
newDoc[fieldName] = value
|
|
}
|
|
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: newDoc,
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
}
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// calculateWindowValue 计算窗口函数值
|
|
func (e *AggregationEngine) calculateWindowValue(
|
|
windowSpec map[string]interface{},
|
|
partition []types.Document,
|
|
currentIndex int,
|
|
currentDoc types.Document,
|
|
) interface{} {
|
|
|
|
// 解析窗口操作符
|
|
for op, operand := range windowSpec {
|
|
switch op {
|
|
case "$documentNumber":
|
|
return float64(currentIndex + 1)
|
|
|
|
case "$rank":
|
|
return float64(currentIndex + 1)
|
|
|
|
case "$first":
|
|
expr := e.evaluateExpression(partition[0].Data, operand)
|
|
return expr
|
|
|
|
case "$last":
|
|
expr := e.evaluateExpression(partition[len(partition)-1].Data, operand)
|
|
return expr
|
|
|
|
case "$shift":
|
|
n := int(toFloat64(operand))
|
|
targetIndex := currentIndex + n
|
|
if targetIndex < 0 || targetIndex >= len(partition) {
|
|
return nil
|
|
}
|
|
return partition[targetIndex].Data
|
|
|
|
case "$fillDefault":
|
|
val := e.evaluateExpression(currentDoc.Data, operand)
|
|
if val == nil {
|
|
return 0 // 默认值
|
|
}
|
|
return val
|
|
|
|
case "$sum", "$avg", "$min", "$max":
|
|
// 聚合窗口函数
|
|
return e.aggregateWindow(op, operand, partition, currentIndex)
|
|
|
|
default:
|
|
// 普通表达式
|
|
return e.evaluateExpression(currentDoc.Data, windowSpec)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// aggregateWindow 聚合窗口函数
|
|
func (e *AggregationEngine) aggregateWindow(
|
|
op string,
|
|
operand interface{},
|
|
partition []types.Document,
|
|
currentIndex int,
|
|
) interface{} {
|
|
var values []float64
|
|
|
|
for i, doc := range partition {
|
|
// 根据窗口范围决定是否包含
|
|
windowSpec := getWindowRange(op, operand)
|
|
if !inWindow(i, currentIndex, windowSpec) {
|
|
continue
|
|
}
|
|
|
|
val := e.evaluateExpression(doc.Data, operand)
|
|
if num, ok := toNumber(val); ok {
|
|
values = append(values, num)
|
|
}
|
|
}
|
|
|
|
if len(values) == 0 {
|
|
return nil
|
|
}
|
|
|
|
switch op {
|
|
case "$sum":
|
|
sum := 0.0
|
|
for _, v := range values {
|
|
sum += v
|
|
}
|
|
return sum
|
|
|
|
case "$avg":
|
|
sum := 0.0
|
|
for _, v := range values {
|
|
sum += v
|
|
}
|
|
return sum / float64(len(values))
|
|
|
|
case "$min":
|
|
min := values[0]
|
|
for _, v := range values[1:] {
|
|
if v < min {
|
|
min = v
|
|
}
|
|
}
|
|
return min
|
|
|
|
case "$max":
|
|
max := values[0]
|
|
for _, v := range values[1:] {
|
|
if v > max {
|
|
max = v
|
|
}
|
|
}
|
|
return max
|
|
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// getWindowRange returns the window description for a window operator.
// Simplified implementation: op and operand are ignored and the window always
// covers the whole partition, expressed as {"window": "all"}.
func getWindowRange(op string, operand interface{}) map[string]interface{} {
	wholePartition := make(map[string]interface{}, 1)
	wholePartition["window"] = "all"
	return wholePartition
}
|
|
|
|
// inWindow reports whether the document at index belongs to the window of the
// document at current. Simplified implementation: no frame bounds are
// interpreted, so every index is accepted and all arguments are ignored.
func inWindow(index, current int, windowSpec map[string]interface{}) bool {
	const everyIndexIncluded = true
	return everyIndexIncluded
}
|
|
|
|
// executeTextSearch 执行 $text 文本搜索
|
|
func (e *AggregationEngine) executeTextSearch(docs []types.Document, search string, language string, caseSensitive bool) ([]types.Document, error) {
|
|
var results []types.Document
|
|
|
|
// 分词搜索
|
|
searchTerms := strings.Fields(strings.ToLower(search))
|
|
|
|
for _, doc := range docs {
|
|
score := e.calculateTextScore(doc.Data, searchTerms, caseSensitive)
|
|
if score > 0 {
|
|
// 添加文本得分
|
|
newDoc := make(map[string]interface{})
|
|
for k, v := range doc.Data {
|
|
newDoc[k] = v
|
|
}
|
|
newDoc["_textScore"] = score
|
|
results = append(results, types.Document{
|
|
ID: doc.ID,
|
|
Data: newDoc,
|
|
CreatedAt: doc.CreatedAt,
|
|
UpdatedAt: doc.UpdatedAt,
|
|
})
|
|
}
|
|
}
|
|
|
|
// 按文本得分排序
|
|
sort.Slice(results, func(i, j int) bool {
|
|
scoreI := results[i].Data["_textScore"].(float64)
|
|
scoreJ := results[j].Data["_textScore"].(float64)
|
|
return scoreI > scoreJ
|
|
})
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// calculateTextScore 计算文本匹配得分
|
|
func (e *AggregationEngine) calculateTextScore(doc map[string]interface{}, searchTerms []string, caseSensitive bool) float64 {
|
|
score := 0.0
|
|
|
|
// 递归搜索所有字符串字段
|
|
e.searchInValue(doc, searchTerms, caseSensitive, &score)
|
|
|
|
return score
|
|
}
|
|
|
|
// searchInValue 在值中搜索
|
|
func (e *AggregationEngine) searchInValue(value interface{}, searchTerms []string, caseSensitive bool, score *float64) {
|
|
switch v := value.(type) {
|
|
case string:
|
|
if !caseSensitive {
|
|
v = strings.ToLower(v)
|
|
}
|
|
for _, term := range searchTerms {
|
|
searchTerm := term
|
|
if !caseSensitive {
|
|
searchTerm = strings.ToLower(term)
|
|
}
|
|
if strings.Contains(v, searchTerm) {
|
|
*score += 1.0
|
|
}
|
|
}
|
|
|
|
case []interface{}:
|
|
for _, item := range v {
|
|
e.searchInValue(item, searchTerms, caseSensitive, score)
|
|
}
|
|
|
|
case map[string]interface{}:
|
|
for _, val := range v {
|
|
e.searchInValue(val, searchTerms, caseSensitive, score)
|
|
}
|
|
}
|
|
}
|
|
|
|
// sortDocsBySpec 根据规范对文档排序
|
|
func sortDocsBySpec(docs []types.Document, sortByRaw map[string]interface{}) {
|
|
type sortKeys struct {
|
|
doc types.Document
|
|
keys []float64
|
|
}
|
|
|
|
keys := make([]sortKeys, len(docs))
|
|
for i, doc := range docs {
|
|
var docKeys []float64
|
|
for _, fieldRaw := range sortByRaw {
|
|
field := getFieldValueStrFromDoc(doc, fieldRaw)
|
|
if num, ok := toNumber(field); ok {
|
|
docKeys = append(docKeys, num)
|
|
} else {
|
|
docKeys = append(docKeys, 0)
|
|
}
|
|
}
|
|
keys[i] = sortKeys{doc: doc, keys: docKeys}
|
|
}
|
|
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
for k := range keys[i].keys {
|
|
if keys[i].keys[k] != keys[j].keys[k] {
|
|
return keys[i].keys[k] < keys[j].keys[k]
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
|
|
for i, k := range keys {
|
|
docs[i] = k.doc
|
|
}
|
|
}
|
|
|
|
// getFieldValueStrFromDoc 从文档获取字段值
|
|
func getFieldValueStrFromDoc(doc types.Document, fieldRaw interface{}) interface{} {
|
|
if fieldStr, ok := fieldRaw.(string); ok {
|
|
return getNestedValue(doc.Data, fieldStr)
|
|
}
|
|
return fieldRaw
|
|
}
|
|
|
|
// valuesEqual reports whether a and b are equal. Two nil values are equal and
// nil never equals a non-nil value; otherwise the default fmt renderings of
// the two values are compared, so values of different types that print the
// same (for example int 1 and float64 1) are considered equal.
func valuesEqual(a, b interface{}) bool {
	if a == nil || b == nil {
		return a == nil && b == nil
	}
	return fmt.Sprint(a) == fmt.Sprint(b)
}
|
|
|
|
// getRandomDocuments 随机获取指定数量的文档
|
|
func getRandomDocuments(docs []types.Document, n int) []types.Document {
|
|
if n >= len(docs) {
|
|
return docs
|
|
}
|
|
|
|
// 随机打乱
|
|
rand.Seed(time.Now().UnixNano())
|
|
rand.Shuffle(len(docs), func(i, j int) {
|
|
docs[i], docs[j] = docs[j], docs[i]
|
|
})
|
|
|
|
return docs[:n]
|
|
}
|