gomog/internal/engine/projection.go

180 lines
3.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package engine
import (
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// applyProjection applies projection to every document in docs and returns
// the projected documents in the same order.
//
// Each document's Data is projected by applyProjectionToDoc. The "_id" key is
// then added to the projected data from the document's ID unless the
// projection lists "_id" with a value that isTrueValue rejects. The returned
// documents always keep their original ID field.
func applyProjection(docs []types.Document, projection types.Projection) []types.Document {
	// _id is kept by default; it is dropped only when explicitly switched off.
	idSpec, idListed := projection["_id"]
	keepID := !idListed || isTrueValue(idSpec)

	projectedDocs := make([]types.Document, 0, len(docs))
	for _, doc := range docs {
		fields := applyProjectionToDoc(doc.Data, projection)
		if keepID {
			fields["_id"] = doc.ID
		}
		projectedDocs = append(projectedDocs, types.Document{
			ID:   doc.ID,
			Data: fields,
		})
	}
	return projectedDocs
}
// applyProjectionToDoc applies projection to a single document's data and
// returns a new map; data itself is never modified.
//
// The "_id" entry of the projection is ignored here (the caller decides
// whether to attach the identifier), and the returned map never carries an
// "_id" key of its own.
//
// Mode selection: if any non-_id entry is truthy (per isTrueValue) or carries
// a $slice / $elemMatch operator, the projection runs in inclusion mode and
// only the listed fields are returned. Mixed projections therefore resolve to
// inclusion mode and their falsy entries are ignored. Otherwise the
// projection runs in exclusion mode: the document is copied and every listed
// field is removed from the copy.
//
// In inclusion mode the result is keyed by the literal projection key, so a
// dotted path such as "a.b" becomes the flat key "a.b".
//
// NOTE(review): MongoDB does not switch to inclusion mode for a projection
// that contains only $slice; here every operator entry is treated as an
// inclusion, which matches where the operators were handled before. Confirm
// the desired semantics.
func applyProjectionToDoc(data map[string]interface{}, projection types.Projection) map[string]interface{} {
	isInclusionMode := false
	for field, value := range projection {
		if field == "_id" {
			continue
		}
		if _, isOperator := projectionOperatorSpec(value); isOperator || isTrueValue(value) {
			isInclusionMode = true
			break
		}
	}

	if !isInclusionMode {
		// Exclusion mode: start from a deep copy so that removing nested
		// fields cannot reach into maps still owned by the source document.
		result := cloneProjectionMap(data)
		delete(result, "_id") // _id is managed by the caller.
		for field := range projection {
			if field == "_id" {
				continue
			}
			// Every non-_id entry is falsy here, otherwise inclusion mode
			// would have been selected above.
			removeNestedValue(result, field)
		}
		return result
	}

	// Inclusion mode: emit only the listed fields, each computed exactly once
	// so operator results are not overwritten by the raw field value.
	result := make(map[string]interface{}, len(projection))
	for field, include := range projection {
		if field == "_id" {
			continue
		}
		if spec, isOperator := projectionOperatorSpec(include); isOperator {
			// $slice takes precedence when both operators are present.
			if sliceVal, hasSlice := spec["$slice"]; hasSlice {
				result[field] = projectSlice(data, field, sliceVal)
			} else {
				result[field] = projectElemMatch(data, field, spec)
			}
			continue
		}
		if isTrueValue(include) {
			result[field] = getNestedValue(data, field)
		}
	}
	return result
}

// projectionOperatorSpec returns value as an operator document when it is a
// map containing a "$slice" or "$elemMatch" key.
func projectionOperatorSpec(value interface{}) (map[string]interface{}, bool) {
	spec, ok := value.(map[string]interface{})
	if !ok {
		return nil, false
	}
	if _, hasSlice := spec["$slice"]; hasSlice {
		return spec, true
	}
	if _, hasElemMatch := spec["$elemMatch"]; hasElemMatch {
		return spec, true
	}
	return nil, false
}

// cloneProjectionMap returns a deep copy of src; the result is always non-nil.
func cloneProjectionMap(src map[string]interface{}) map[string]interface{} {
	dst := make(map[string]interface{}, len(src))
	for key, value := range src {
		dst[key] = cloneProjectionValue(value)
	}
	return dst
}

// cloneProjectionValue deep-copies nested map[string]interface{} and
// []interface{} values; every other value is returned as-is.
func cloneProjectionValue(value interface{}) interface{} {
	switch typed := value.(type) {
	case map[string]interface{}:
		if typed == nil {
			return value
		}
		return cloneProjectionMap(typed)
	case []interface{}:
		if typed == nil {
			return value
		}
		items := make([]interface{}, len(typed))
		for i, item := range typed {
			items[i] = cloneProjectionValue(item)
		}
		return items
	default:
		return value
	}
}
// projectElemMatch picks one element out of the array stored at field.
//
// It returns nil when the value at field is missing, is not a []interface{},
// or is empty. When spec["$elemMatch"] is not a map, the first array element
// is returned. Otherwise it returns the first element that is itself a map
// and satisfies MatchFilter against the $elemMatch condition, or nil when no
// element qualifies. The element itself is returned, not a one-element array.
func projectElemMatch(data map[string]interface{}, field string, spec map[string]interface{}) interface{} {
	// A failed assertion covers both a missing value (nil) and a non-array.
	elements, isArray := getNestedValue(data, field).([]interface{})
	if !isArray || len(elements) == 0 {
		return nil
	}

	condition, isCondition := spec["$elemMatch"].(map[string]interface{})
	if !isCondition {
		// No usable condition: fall back to the first element.
		return elements[0]
	}

	for _, candidate := range elements {
		subDoc, isDoc := candidate.(map[string]interface{})
		if isDoc && MatchFilter(subDoc, condition) {
			return candidate
		}
	}
	return nil
}
// projectSlice applies a $slice projection to the array stored at field.
//
// sliceSpec takes one of two shapes:
//   - a single number n: n >= 0 keeps the first n elements, n < 0 keeps the
//     last |n| elements;
//   - a two-element array [skip, limit]: skip elements are skipped (a negative
//     skip counts back from the end of the array, clamped to the start) and
//     up to limit elements are kept.
//
// A non-positive limit in the two-element form, an array spec with fewer than
// two entries, or an unsupported spec type yields an empty array.
//
// It returns nil when getNestedValue finds nothing at field, and returns the
// value unchanged when it is not a []interface{}. The returned slice is a
// capacity-limited view of the stored array, so appending to it cannot
// overwrite the stored elements.
func projectSlice(data map[string]interface{}, field string, sliceSpec interface{}) interface{} {
	arr := getNestedValue(data, field)
	if arr == nil {
		return nil
	}
	array, ok := arr.([]interface{})
	if !ok {
		return arr
	}

	n := len(array)
	start, count := 0, 0
	if pair, isPair := sliceSpec.([]interface{}); isPair {
		// {$slice: [skip, limit]}
		if len(pair) >= 2 {
			start = int(toFloat64(pair[0]))
			count = int(toFloat64(pair[1]))
			if start < 0 {
				// Negative skip is an offset from the end of the array.
				if start < -n {
					start = -n
				}
				start += n
			}
		}
	} else if c, isNumber := sliceSpecCount(sliceSpec); isNumber {
		// {$slice: n}
		if c < 0 {
			// Keep the last |c| elements; clamping first avoids overflow on -c.
			if c < -n {
				c = -n
			}
			start = n + c
			count = -c
		} else {
			count = c
		}
	}

	if count <= 0 || start >= n {
		return []interface{}{}
	}
	if count > n-start {
		count = n - start
	}
	end := start + count
	return array[start:end:end]
}

// sliceSpecCount converts a scalar $slice argument to an int and reports
// whether v was one of the supported numeric types.
func sliceSpecCount(v interface{}) (int, bool) {
	switch c := v.(type) {
	case int:
		return c, true
	case int32:
		return int(c), true
	case int64:
		return int(c), true
	case float32:
		return int(c), true
	case float64:
		return int(c), true
	default:
		return 0, false
	}
}