mirror of https://github.com/usememos/memos.git
refactor: memo filter
This commit is contained in:
parent
1a75d19a89
commit
ed23cbc011
|
|
@ -23,6 +23,14 @@ func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
|
|||
}
|
||||
}
|
||||
|
||||
// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset.
|
||||
func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter {
|
||||
return &CommonSQLConverter{
|
||||
dialect: dialect,
|
||||
paramIndex: offset + 1,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect.
|
||||
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
|
|
@ -114,7 +122,7 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE
|
|||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
|
|
@ -132,7 +140,7 @@ func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callE
|
|||
return c.handleStringComparison(ctx, identifier, operator, value)
|
||||
case "creator_id":
|
||||
return c.handleIntComparison(ctx, identifier, operator, value)
|
||||
case "has_task_list":
|
||||
case "has_task_list", "has_link", "has_code", "has_incomplete_tasks":
|
||||
return c.handleBooleanComparison(ctx, identifier, operator, value)
|
||||
}
|
||||
|
||||
|
|
@ -226,15 +234,18 @@ func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExp
|
|||
}
|
||||
|
||||
// Use dialect-specific JSON contains logic
|
||||
sqlExpr := c.dialect.GetJSONContains("$.tags", "element")
|
||||
template := c.dialect.GetJSONContains("$.tags", "element")
|
||||
sqlExpr := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For SQLite, we need a different approach since it uses LIKE
|
||||
// Handle args based on dialect
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
// SQLite uses LIKE with pattern
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
|
||||
} else {
|
||||
// MySQL and PostgreSQL expect plain values
|
||||
ctx.Args = append(ctx.Args, element)
|
||||
}
|
||||
c.paramIndex++
|
||||
|
|
@ -251,7 +262,10 @@ func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any)
|
|||
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
|
||||
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
||||
} else {
|
||||
subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element"))
|
||||
// Replace ? with proper placeholder for each dialect
|
||||
template := c.dialect.GetJSONContains("$.tags", "element")
|
||||
sql := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1)
|
||||
subconditions = append(subconditions, sql)
|
||||
args = append(args, v)
|
||||
}
|
||||
c.paramIndex++
|
||||
|
|
@ -279,8 +293,14 @@ func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values
|
|||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
|
|
@ -307,8 +327,16 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp
|
|||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
|
||||
// PostgreSQL uses ILIKE and no backticks
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content ILIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
|
|
@ -320,19 +348,37 @@ func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExp
|
|||
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
|
||||
identifier := identExpr.GetName()
|
||||
|
||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
|
||||
if identifier == "pinned" {
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_link" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasLink")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_code" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasCode")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_incomplete_tasks" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasIncompleteTasks")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -366,15 +412,23 @@ func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field,
|
|||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
fieldName := field
|
||||
if field == "visibility" {
|
||||
fieldName = "`visibility`"
|
||||
} else if field == "content" {
|
||||
fieldName = "`content`"
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// MySQL and SQLite use backticks
|
||||
fieldName := field
|
||||
if field == "visibility" {
|
||||
fieldName = "`visibility`"
|
||||
} else if field == "content" {
|
||||
fieldName = "`content`"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
|
|
@ -394,8 +448,17 @@ func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, ope
|
|||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// MySQL and SQLite use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
|
|
@ -411,18 +474,121 @@ func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field,
|
|||
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
return errors.Errorf("invalid boolean value for %s", field)
|
||||
}
|
||||
|
||||
sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
// Map field name to JSON path
|
||||
var jsonPath string
|
||||
switch field {
|
||||
case "has_task_list":
|
||||
jsonPath = "$.property.hasTaskList"
|
||||
case "has_link":
|
||||
jsonPath = "$.property.hasLink"
|
||||
case "has_code":
|
||||
jsonPath = "$.property.hasCode"
|
||||
case "has_incomplete_tasks":
|
||||
jsonPath = "$.property.hasIncompleteTasks"
|
||||
}
|
||||
|
||||
// For dialects that need parameters (PostgreSQL)
|
||||
// Special handling for SQLite based on field
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
if field == "has_task_list" {
|
||||
// has_task_list uses = 1 / = 0 / != 1 / != 0
|
||||
var sqlExpr string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("%s = 1", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s = 0", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("%s != 1", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s != 0", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
// Other fields use IS TRUE / NOT(... IS TRUE)
|
||||
var sqlExpr string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Special handling for MySQL - use raw operator with CAST
|
||||
if _, ok := c.dialect.(*MySQLDialect); ok {
|
||||
var sqlExpr string
|
||||
boolStr := "false"
|
||||
if valueBool {
|
||||
boolStr = "true"
|
||||
}
|
||||
sqlExpr = fmt.Sprintf("%s %s CAST('%s' AS JSON)", c.dialect.GetJSONExtract(jsonPath), operator, boolStr)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle PostgreSQL differently - it uses the raw operator
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
var jsonExtract string
|
||||
// Special handling for has_link, has_code, has_incomplete_tasks
|
||||
if field == "has_link" || field == "has_code" || field == "has_incomplete_tasks" {
|
||||
// Use memo-> format for these fields
|
||||
parts := strings.Split(strings.TrimPrefix(jsonPath, "$."), ".")
|
||||
jsonExtract = "memo->'payload'"
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
jsonExtract += fmt.Sprintf("->>'%s'", part)
|
||||
} else {
|
||||
jsonExtract += fmt.Sprintf("->'%s'", part)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use standard format for has_task_list
|
||||
jsonExtract = c.dialect.GetJSONExtract(jsonPath)
|
||||
}
|
||||
|
||||
sqlExpr := fmt.Sprintf("(%s)::boolean %s %s",
|
||||
jsonExtract,
|
||||
operator,
|
||||
c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
c.paramIndex++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle other dialects
|
||||
if operator == "!=" {
|
||||
valueBool = !valueBool
|
||||
}
|
||||
|
||||
sqlExpr := c.dialect.GetBooleanComparison(jsonPath, valueBool)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -85,7 +85,10 @@ func (*SQLiteDialect) GetBooleanValue(value bool) interface{} {
|
|||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
|
||||
return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value))
|
||||
if value {
|
||||
return fmt.Sprintf("%s = 1", d.GetJSONExtract(path))
|
||||
}
|
||||
return fmt.Sprintf("%s = 0", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
||||
|
|
@ -132,11 +135,10 @@ func (*MySQLDialect) GetBooleanValue(value bool) interface{} {
|
|||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
|
||||
boolStr := "false"
|
||||
if value {
|
||||
boolStr = "true"
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr)
|
||||
return fmt.Sprintf("%s != CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
||||
|
|
@ -163,7 +165,7 @@ func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
|||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||
// Convert $.property.hasTaskList to payload->'property'->>'hasTaskList'
|
||||
// Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList'
|
||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
|
||||
for i, part := range parts {
|
||||
|
|
@ -196,10 +198,26 @@ func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
|||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string {
|
||||
// Note: The parameter placeholder will be replaced by the caller
|
||||
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
||||
// Special handling for standalone boolean identifiers
|
||||
if strings.Contains(path, "hasLink") || strings.Contains(path, "hasCode") || strings.Contains(path, "hasIncompleteTasks") {
|
||||
// Use memo-> instead of memo.payload-> for these fields
|
||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||
result := fmt.Sprintf("%s->'payload'", d.GetTablePrefix())
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
result += fmt.Sprintf("->>'%s'", part)
|
||||
} else {
|
||||
result += fmt.Sprintf("->'%s'", part)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("(%s)::boolean = true", result)
|
||||
}
|
||||
// Use default format for other fields
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,12 +26,8 @@ service MemoService {
|
|||
}
|
||||
// ListMemos lists memos with pagination and filter.
|
||||
rpc ListMemos(ListMemosRequest) returns (ListMemosResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/api/v1/memos"
|
||||
additional_bindings: {get: "/api/v1/{parent=users/*}/memos"}
|
||||
};
|
||||
option (google.api.http) = {get: "/api/v1/memos"};
|
||||
option (google.api.method_signature) = "";
|
||||
option (google.api.method_signature) = "parent";
|
||||
}
|
||||
// GetMemo gets a memo.
|
||||
rpc GetMemo(GetMemoRequest) returns (Memo) {
|
||||
|
|
@ -276,27 +272,19 @@ message CreateMemoRequest {
|
|||
}
|
||||
|
||||
message ListMemosRequest {
|
||||
// Optional. The parent is the owner of the memos.
|
||||
// If not specified or `users/-`, it will list all memos.
|
||||
// Format: users/{user}
|
||||
string parent = 1 [
|
||||
(google.api.field_behavior) = OPTIONAL,
|
||||
(google.api.resource_reference) = {type: "memos.api.v1/User"}
|
||||
];
|
||||
|
||||
// Optional. The maximum number of memos to return.
|
||||
// The service may return fewer than this value.
|
||||
// If unspecified, at most 50 memos will be returned.
|
||||
// The maximum value is 1000; values above 1000 will be coerced to 1000.
|
||||
int32 page_size = 2 [(google.api.field_behavior) = OPTIONAL];
|
||||
int32 page_size = 1 [(google.api.field_behavior) = OPTIONAL];
|
||||
|
||||
// Optional. A page token, received from a previous `ListMemos` call.
|
||||
// Provide this to retrieve the subsequent page.
|
||||
string page_token = 3 [(google.api.field_behavior) = OPTIONAL];
|
||||
string page_token = 2 [(google.api.field_behavior) = OPTIONAL];
|
||||
|
||||
// Optional. The state of the memos to list.
|
||||
// Default to `NORMAL`. Set to `ARCHIVED` to list archived memos.
|
||||
State state = 4 [(google.api.field_behavior) = OPTIONAL];
|
||||
State state = 3 [(google.api.field_behavior) = OPTIONAL];
|
||||
|
||||
// Optional. The order to sort results by.
|
||||
// Default to "display_time desc".
|
||||
|
|
|
|||
|
|
@ -551,21 +551,17 @@ func (x *CreateMemoRequest) GetRequestId() string {
|
|||
|
||||
type ListMemosRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Optional. The parent is the owner of the memos.
|
||||
// If not specified or `users/-`, it will list all memos.
|
||||
// Format: users/{user}
|
||||
Parent string `protobuf:"bytes,1,opt,name=parent,proto3" json:"parent,omitempty"`
|
||||
// Optional. The maximum number of memos to return.
|
||||
// The service may return fewer than this value.
|
||||
// If unspecified, at most 50 memos will be returned.
|
||||
// The maximum value is 1000; values above 1000 will be coerced to 1000.
|
||||
PageSize int32 `protobuf:"varint,2,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"`
|
||||
PageSize int32 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"`
|
||||
// Optional. A page token, received from a previous `ListMemos` call.
|
||||
// Provide this to retrieve the subsequent page.
|
||||
PageToken string `protobuf:"bytes,3,opt,name=page_token,json=pageToken,proto3" json:"page_token,omitempty"`
|
||||
PageToken string `protobuf:"bytes,2,opt,name=page_token,json=pageToken,proto3" json:"page_token,omitempty"`
|
||||
// Optional. The state of the memos to list.
|
||||
// Default to `NORMAL`. Set to `ARCHIVED` to list archived memos.
|
||||
State State `protobuf:"varint,4,opt,name=state,proto3,enum=memos.api.v1.State" json:"state,omitempty"`
|
||||
State State `protobuf:"varint,3,opt,name=state,proto3,enum=memos.api.v1.State" json:"state,omitempty"`
|
||||
// Optional. The order to sort results by.
|
||||
// Default to "display_time desc".
|
||||
// Example: "display_time desc" or "create_time asc"
|
||||
|
|
@ -610,13 +606,6 @@ func (*ListMemosRequest) Descriptor() ([]byte, []int) {
|
|||
return file_api_v1_memo_service_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *ListMemosRequest) GetParent() string {
|
||||
if x != nil {
|
||||
return x.Parent
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ListMemosRequest) GetPageSize() int32 {
|
||||
if x != nil {
|
||||
return x.PageSize
|
||||
|
|
@ -2064,14 +2053,12 @@ const file_api_v1_memo_service_proto_rawDesc = "" +
|
|||
"\amemo_id\x18\x02 \x01(\tB\x03\xe0A\x01R\x06memoId\x12(\n" +
|
||||
"\rvalidate_only\x18\x03 \x01(\bB\x03\xe0A\x01R\fvalidateOnly\x12\"\n" +
|
||||
"\n" +
|
||||
"request_id\x18\x04 \x01(\tB\x03\xe0A\x01R\trequestId\"\xa0\x02\n" +
|
||||
"\x10ListMemosRequest\x121\n" +
|
||||
"\x06parent\x18\x01 \x01(\tB\x19\xe0A\x01\xfaA\x13\n" +
|
||||
"\x11memos.api.v1/UserR\x06parent\x12 \n" +
|
||||
"\tpage_size\x18\x02 \x01(\x05B\x03\xe0A\x01R\bpageSize\x12\"\n" +
|
||||
"request_id\x18\x04 \x01(\tB\x03\xe0A\x01R\trequestId\"\xed\x01\n" +
|
||||
"\x10ListMemosRequest\x12 \n" +
|
||||
"\tpage_size\x18\x01 \x01(\x05B\x03\xe0A\x01R\bpageSize\x12\"\n" +
|
||||
"\n" +
|
||||
"page_token\x18\x03 \x01(\tB\x03\xe0A\x01R\tpageToken\x12.\n" +
|
||||
"\x05state\x18\x04 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x01R\x05state\x12\x1e\n" +
|
||||
"page_token\x18\x02 \x01(\tB\x03\xe0A\x01R\tpageToken\x12.\n" +
|
||||
"\x05state\x18\x03 \x01(\x0e2\x13.memos.api.v1.StateB\x03\xe0A\x01R\x05state\x12\x1e\n" +
|
||||
"\border_by\x18\x05 \x01(\tB\x03\xe0A\x01R\aorderBy\x12\x1b\n" +
|
||||
"\x06filter\x18\x06 \x01(\tB\x03\xe0A\x01R\x06filter\x12&\n" +
|
||||
"\fshow_deleted\x18\a \x01(\bB\x03\xe0A\x01R\vshowDeleted\"\x84\x01\n" +
|
||||
|
|
@ -2187,11 +2174,11 @@ const file_api_v1_memo_service_proto_rawDesc = "" +
|
|||
"\aPRIVATE\x10\x01\x12\r\n" +
|
||||
"\tPROTECTED\x10\x02\x12\n" +
|
||||
"\n" +
|
||||
"\x06PUBLIC\x10\x032\x97\x11\n" +
|
||||
"\x06PUBLIC\x10\x032\xeb\x10\n" +
|
||||
"\vMemoService\x12e\n" +
|
||||
"\n" +
|
||||
"CreateMemo\x12\x1f.memos.api.v1.CreateMemoRequest\x1a\x12.memos.api.v1.Memo\"\"\xdaA\x04memo\x82\xd3\xe4\x93\x02\x15:\x04memo\"\r/api/v1/memos\x12\x91\x01\n" +
|
||||
"\tListMemos\x12\x1e.memos.api.v1.ListMemosRequest\x1a\x1f.memos.api.v1.ListMemosResponse\"C\xdaA\x00\xdaA\x06parent\x82\xd3\xe4\x93\x021Z \x12\x1e/api/v1/{parent=users/*}/memos\x12\r/api/v1/memos\x12b\n" +
|
||||
"CreateMemo\x12\x1f.memos.api.v1.CreateMemoRequest\x1a\x12.memos.api.v1.Memo\"\"\xdaA\x04memo\x82\xd3\xe4\x93\x02\x15:\x04memo\"\r/api/v1/memos\x12f\n" +
|
||||
"\tListMemos\x12\x1e.memos.api.v1.ListMemosRequest\x1a\x1f.memos.api.v1.ListMemosResponse\"\x18\xdaA\x00\x82\xd3\xe4\x93\x02\x0f\x12\r/api/v1/memos\x12b\n" +
|
||||
"\aGetMemo\x12\x1c.memos.api.v1.GetMemoRequest\x1a\x12.memos.api.v1.Memo\"%\xdaA\x04name\x82\xd3\xe4\x93\x02\x18\x12\x16/api/v1/{name=memos/*}\x12\x7f\n" +
|
||||
"\n" +
|
||||
"UpdateMemo\x12\x1f.memos.api.v1.UpdateMemoRequest\x1a\x12.memos.api.v1.Memo\"<\xdaA\x10memo,update_mask\x82\xd3\xe4\x93\x02#:\x04memo2\x1b/api/v1/{memo.name=memos/*}\x12l\n" +
|
||||
|
|
|
|||
|
|
@ -111,59 +111,6 @@ func local_request_MemoService_ListMemos_0(ctx context.Context, marshaler runtim
|
|||
return msg, metadata, err
|
||||
}
|
||||
|
||||
var filter_MemoService_ListMemos_1 = &utilities.DoubleArray{Encoding: map[string]int{"parent": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}}
|
||||
|
||||
func request_MemoService_ListMemos_1(ctx context.Context, marshaler runtime.Marshaler, client MemoServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
|
||||
var (
|
||||
protoReq ListMemosRequest
|
||||
metadata runtime.ServerMetadata
|
||||
err error
|
||||
)
|
||||
if req.Body != nil {
|
||||
_, _ = io.Copy(io.Discard, req.Body)
|
||||
}
|
||||
val, ok := pathParams["parent"]
|
||||
if !ok {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "parent")
|
||||
}
|
||||
protoReq.Parent, err = runtime.String(val)
|
||||
if err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "parent", err)
|
||||
}
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_MemoService_ListMemos_1); err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
msg, err := client.ListMemos(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
|
||||
return msg, metadata, err
|
||||
}
|
||||
|
||||
func local_request_MemoService_ListMemos_1(ctx context.Context, marshaler runtime.Marshaler, server MemoServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
|
||||
var (
|
||||
protoReq ListMemosRequest
|
||||
metadata runtime.ServerMetadata
|
||||
err error
|
||||
)
|
||||
val, ok := pathParams["parent"]
|
||||
if !ok {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "parent")
|
||||
}
|
||||
protoReq.Parent, err = runtime.String(val)
|
||||
if err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "parent", err)
|
||||
}
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_MemoService_ListMemos_1); err != nil {
|
||||
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
msg, err := server.ListMemos(ctx, &protoReq)
|
||||
return msg, metadata, err
|
||||
}
|
||||
|
||||
var filter_MemoService_GetMemo_0 = &utilities.DoubleArray{Encoding: map[string]int{"name": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}}
|
||||
|
||||
func request_MemoService_GetMemo_0(ctx context.Context, marshaler runtime.Marshaler, client MemoServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
|
||||
|
|
@ -956,26 +903,6 @@ func RegisterMemoServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux
|
|||
}
|
||||
forward_MemoService_ListMemos_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
|
||||
})
|
||||
mux.Handle(http.MethodGet, pattern_MemoService_ListMemos_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
var stream runtime.ServerTransportStream
|
||||
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
|
||||
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
|
||||
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/memos.api.v1.MemoService/ListMemos", runtime.WithHTTPPathPattern("/api/v1/{parent=users/*}/memos"))
|
||||
if err != nil {
|
||||
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
}
|
||||
resp, md, err := local_request_MemoService_ListMemos_1(annotatedContext, inboundMarshaler, server, req, pathParams)
|
||||
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
|
||||
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
|
||||
if err != nil {
|
||||
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
}
|
||||
forward_MemoService_ListMemos_1(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
|
||||
})
|
||||
mux.Handle(http.MethodGet, pattern_MemoService_GetMemo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
|
|
@ -1330,23 +1257,6 @@ func RegisterMemoServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux
|
|||
}
|
||||
forward_MemoService_ListMemos_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
|
||||
})
|
||||
mux.Handle(http.MethodGet, pattern_MemoService_ListMemos_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
|
||||
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/memos.api.v1.MemoService/ListMemos", runtime.WithHTTPPathPattern("/api/v1/{parent=users/*}/memos"))
|
||||
if err != nil {
|
||||
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
}
|
||||
resp, md, err := request_MemoService_ListMemos_1(annotatedContext, inboundMarshaler, client, req, pathParams)
|
||||
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
|
||||
if err != nil {
|
||||
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
|
||||
return
|
||||
}
|
||||
forward_MemoService_ListMemos_1(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
|
||||
})
|
||||
mux.Handle(http.MethodGet, pattern_MemoService_GetMemo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
|
|
@ -1591,7 +1501,6 @@ func RegisterMemoServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux
|
|||
var (
|
||||
pattern_MemoService_CreateMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "memos"}, ""))
|
||||
pattern_MemoService_ListMemos_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "memos"}, ""))
|
||||
pattern_MemoService_ListMemos_1 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3, 2, 4}, []string{"api", "v1", "users", "parent", "memos"}, ""))
|
||||
pattern_MemoService_GetMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3}, []string{"api", "v1", "memos", "name"}, ""))
|
||||
pattern_MemoService_UpdateMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3}, []string{"api", "v1", "memos", "memo.name"}, ""))
|
||||
pattern_MemoService_DeleteMemo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 2, 5, 3}, []string{"api", "v1", "memos", "name"}, ""))
|
||||
|
|
@ -1611,7 +1520,6 @@ var (
|
|||
var (
|
||||
forward_MemoService_CreateMemo_0 = runtime.ForwardResponseMessage
|
||||
forward_MemoService_ListMemos_0 = runtime.ForwardResponseMessage
|
||||
forward_MemoService_ListMemos_1 = runtime.ForwardResponseMessage
|
||||
forward_MemoService_GetMemo_0 = runtime.ForwardResponseMessage
|
||||
forward_MemoService_UpdateMemo_0 = runtime.ForwardResponseMessage
|
||||
forward_MemoService_DeleteMemo_0 = runtime.ForwardResponseMessage
|
||||
|
|
|
|||
|
|
@ -622,14 +622,6 @@ paths:
|
|||
description: ListMemos lists memos with pagination and filter.
|
||||
operationId: MemoService_ListMemos
|
||||
parameters:
|
||||
- name: parent
|
||||
in: query
|
||||
description: |-
|
||||
Optional. The parent is the owner of the memos.
|
||||
If not specified or `users/-`, it will list all memos.
|
||||
Format: users/{user}
|
||||
schema:
|
||||
type: string
|
||||
- name: pageSize
|
||||
in: query
|
||||
description: |-
|
||||
|
|
@ -1597,82 +1589,6 @@ paths:
|
|||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Status'
|
||||
/api/v1/users/{user}/memos:
|
||||
get:
|
||||
tags:
|
||||
- MemoService
|
||||
description: ListMemos lists memos with pagination and filter.
|
||||
operationId: MemoService_ListMemos
|
||||
parameters:
|
||||
- name: user
|
||||
in: path
|
||||
description: The user id.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: pageSize
|
||||
in: query
|
||||
description: |-
|
||||
Optional. The maximum number of memos to return.
|
||||
The service may return fewer than this value.
|
||||
If unspecified, at most 50 memos will be returned.
|
||||
The maximum value is 1000; values above 1000 will be coerced to 1000.
|
||||
schema:
|
||||
type: integer
|
||||
format: int32
|
||||
- name: pageToken
|
||||
in: query
|
||||
description: |-
|
||||
Optional. A page token, received from a previous `ListMemos` call.
|
||||
Provide this to retrieve the subsequent page.
|
||||
schema:
|
||||
type: string
|
||||
- name: state
|
||||
in: query
|
||||
description: |-
|
||||
Optional. The state of the memos to list.
|
||||
Default to `NORMAL`. Set to `ARCHIVED` to list archived memos.
|
||||
schema:
|
||||
enum:
|
||||
- STATE_UNSPECIFIED
|
||||
- NORMAL
|
||||
- ARCHIVED
|
||||
type: string
|
||||
format: enum
|
||||
- name: orderBy
|
||||
in: query
|
||||
description: |-
|
||||
Optional. The order to sort results by.
|
||||
Default to "display_time desc".
|
||||
Example: "display_time desc" or "create_time asc"
|
||||
schema:
|
||||
type: string
|
||||
- name: filter
|
||||
in: query
|
||||
description: |-
|
||||
Optional. Filter to apply to the list results.
|
||||
Filter is a CEL expression to filter memos.
|
||||
Refer to `Shortcut.filter`.
|
||||
schema:
|
||||
type: string
|
||||
- name: showDeleted
|
||||
in: query
|
||||
description: Optional. If true, show deleted memos in the response.
|
||||
schema:
|
||||
type: boolean
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListMemosResponse'
|
||||
default:
|
||||
description: Default error response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Status'
|
||||
/api/v1/users/{user}/sessions:
|
||||
get:
|
||||
tags:
|
||||
|
|
|
|||
|
|
@ -99,13 +99,6 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
|
|||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
}
|
||||
if request.Parent != "" && request.Parent != "users/-" {
|
||||
userID, err := ExtractUserIDFromName(request.Parent)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
|
||||
}
|
||||
memoFind.CreatorID = &userID
|
||||
}
|
||||
if request.State == v1pb.State_ARCHIVED {
|
||||
state := store.Archived
|
||||
memoFind.RowStatus = &state
|
||||
|
|
|
|||
|
|
@ -329,7 +329,23 @@ func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error
|
|||
return errors.Wrap(err, "failed to parse filter")
|
||||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
|
||||
// Determine the dialect based on the actual database driver
|
||||
var dialect filter.SQLDialect
|
||||
switch s.Profile.Driver {
|
||||
case "sqlite":
|
||||
dialect = &filter.SQLiteDialect{}
|
||||
case "mysql":
|
||||
dialect = &filter.MySQLDialect{}
|
||||
case "postgres":
|
||||
dialect = &filter.PostgreSQLDialect{}
|
||||
default:
|
||||
// Default to SQLite for unknown drivers
|
||||
dialect = &filter.SQLiteDialect{}
|
||||
}
|
||||
|
||||
converter := filter.NewCommonSQLConverter(dialect)
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to convert filter to SQL")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,8 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
|
|
|
|||
|
|
@ -1,357 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
return d.convertWithTemplates(ctx, expr)
|
||||
}
|
||||
|
||||
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
const dbType = filter.MySQLTemplate
|
||||
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
// Handle size(tags) comparison
|
||||
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
|
||||
filter.GetSQL("json_array_length", dbType), operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "visibility" || identifier == "content" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
var sqlTemplate string
|
||||
if identifier == "visibility" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
|
||||
} else if identifier == "content" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
} else if identifier == "creator_id" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "has_task_list" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
// Use template for boolean comparison
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_false", dbType)
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.Errorf("invalid boolean value for %s", identifier)
|
||||
}
|
||||
|
||||
// Map identifier to JSON path
|
||||
var jsonPath string
|
||||
switch identifier {
|
||||
case "has_link":
|
||||
jsonPath = "$.property.hasLink"
|
||||
case "has_code":
|
||||
jsonPath = "$.property.hasCode"
|
||||
case "has_incomplete_tasks":
|
||||
jsonPath = "$.property.hasIncompleteTasks"
|
||||
}
|
||||
|
||||
// Use JSON_EXTRACT for boolean comparison like has_task_list
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('true' AS JSON)", jsonPath)
|
||||
} else {
|
||||
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') = CAST('false' AS JSON)", jsonPath)
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('true' AS JSON)", jsonPath)
|
||||
} else {
|
||||
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') != CAST('false' AS JSON)", jsonPath)
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||
// This is "element in collection" - the second argument is the collection
|
||||
if !slices.Contains([]string{"tags"}, identifier) {
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||
}
|
||||
|
||||
if identifier == "tags" {
|
||||
// Handle "element" in tags
|
||||
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if identifier == "tag" {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
for _, v := range values {
|
||||
subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
|
||||
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||
}
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
} else if identifier == "visibility" {
|
||||
placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
|
||||
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
identifier := v.IdentExpr.GetName()
|
||||
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
if identifier == "pinned" {
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
// Handle has_task_list as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_link" {
|
||||
// Handle has_link as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') = CAST('true' AS JSON)"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_code" {
|
||||
// Handle has_code as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') = CAST('true' AS JSON)"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_incomplete_tasks" {
|
||||
// Handle has_incomplete_tasks as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') = CAST('true' AS JSON)"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
|
|
@ -148,11 +148,11 @@ func TestConvertExprToSQL(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
db := &DB{}
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
|
|||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||
convertCtx := filter.NewConvertContext()
|
||||
convertCtx.ArgsOffset = len(args)
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args))
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
|
|
|
|||
|
|
@ -1,373 +0,0 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
const dbType = filter.PostgreSQLTemplate
|
||||
_, err := d.convertWithParameterIndex(ctx, expr, dbType, ctx.ArgsOffset+len(ctx.Args)+1)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *DB) convertWithParameterIndex(ctx *filter.ConvertContext, expr *exprv1.Expr, dbType filter.TemplateDBType, paramIndex int) (int, error) {
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
newParamIndex, err = d.convertWithParameterIndex(ctx, v.CallExpr.Args[1], dbType, newParamIndex)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
return newParamIndex, nil
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
newParamIndex, err := d.convertWithParameterIndex(ctx, v.CallExpr.Args[0], dbType, paramIndex)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
return newParamIndex, nil
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
// Handle size(tags) comparison
|
||||
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||
return paramIndex, errors.New("size function requires exactly one argument")
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if identifier != "tags" {
|
||||
return paramIndex, errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("size comparison value must be an integer")
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||
filter.GetSQL("json_array_length", dbType), operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampSQL, operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "visibility" || identifier == "content" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid string value")
|
||||
}
|
||||
|
||||
var sqlTemplate string
|
||||
if identifier == "visibility" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".visibility"
|
||||
} else if identifier == "content" {
|
||||
sqlTemplate = filter.GetSQL("content_like", dbType)
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", valueStr))
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "creator_id" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid int value")
|
||||
}
|
||||
|
||||
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".creator_id"
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", sqlTemplate, operator,
|
||||
filter.GetParameterPlaceholder(dbType, paramIndex))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "has_task_list" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return paramIndex, errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
// Use parameterized template for boolean comparison (PostgreSQL only)
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
sqlTemplate := fmt.Sprintf(filter.GetSQL("boolean_compare", dbType), operator)
|
||||
sqlTemplate = strings.Replace(sqlTemplate, "?", placeholder, 1)
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
return paramIndex + 1, nil
|
||||
} else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return paramIndex, errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return paramIndex, errors.Errorf("invalid boolean value for %s", identifier)
|
||||
}
|
||||
|
||||
// Map identifier to JSON path
|
||||
var jsonPath string
|
||||
switch identifier {
|
||||
case "has_link":
|
||||
jsonPath = "$.property.hasLink"
|
||||
case "has_code":
|
||||
jsonPath = "$.property.hasCode"
|
||||
case "has_incomplete_tasks":
|
||||
jsonPath = "$.property.hasIncompleteTasks"
|
||||
}
|
||||
|
||||
// Use JSON path for boolean comparison with PostgreSQL parameter placeholder
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean = %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder)
|
||||
} else { // operator == "!="
|
||||
sqlTemplate = fmt.Sprintf("(%s->'payload'->'property'->>'%s')::boolean != %s", filter.GetSQL("table_prefix", dbType), strings.TrimPrefix(jsonPath, "$.property."), placeholder)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||
// This is "element in collection" - the second argument is the collection
|
||||
if !slices.Contains([]string{"tags"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||
}
|
||||
|
||||
if identifier == "tags" {
|
||||
// Handle "element" in tags
|
||||
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
sql := strings.Replace(filter.GetSQL("json_contains_element", dbType), "?", placeholder, 1)
|
||||
if _, err := ctx.Buffer.WriteString(sql); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
return paramIndex, nil
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if identifier == "tag" {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
currentParamIndex := paramIndex
|
||||
for _, v := range values {
|
||||
// Use parameter index for each placeholder
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, currentParamIndex)
|
||||
subcondition := strings.Replace(filter.GetSQL("json_contains_tag", dbType), "?", placeholder, 1)
|
||||
subconditions = append(subconditions, subcondition)
|
||||
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||
currentParamIndex++
|
||||
}
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
}
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
return paramIndex + len(args), nil
|
||||
} else if identifier == "visibility" {
|
||||
placeholders := filter.FormatPlaceholders(dbType, len(values), paramIndex)
|
||||
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return paramIndex + len(values), nil
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return paramIndex, errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
if identifier != "content" {
|
||||
return paramIndex, errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
placeholder := filter.GetParameterPlaceholder(dbType, paramIndex)
|
||||
sql := strings.Replace(filter.GetSQL("content_like", dbType), "?", placeholder, 1)
|
||||
if _, err := ctx.Buffer.WriteString(sql); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
return paramIndex + 1, nil
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
identifier := v.IdentExpr.GetName()
|
||||
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return paramIndex, errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
if identifier == "pinned" {
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".pinned IS TRUE"); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
// Handle has_task_list as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else if identifier == "has_link" {
|
||||
// Handle has_link as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasLink')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else if identifier == "has_code" {
|
||||
// Handle has_code as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasCode')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
} else if identifier == "has_incomplete_tasks" {
|
||||
// Handle has_incomplete_tasks as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s->'payload'->'property'->>'hasIncompleteTasks')::boolean = true", filter.GetSQL("table_prefix", dbType))); err != nil {
|
||||
return paramIndex, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return paramIndex, nil
|
||||
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
|
|
@ -148,11 +148,11 @@ func TestConvertExprToSQL(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
db := &DB{}
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args))
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
|
|
|
|||
|
|
@ -58,7 +58,8 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
|
|||
convertCtx := filter.NewConvertContext()
|
||||
convertCtx.ArgsOffset = len(args)
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args))
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
|
|
|
|||
|
|
@ -1,357 +0,0 @@
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
return d.convertWithTemplates(ctx, expr)
|
||||
}
|
||||
|
||||
func (d *DB) convertWithTemplates(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||
const dbType = filter.SQLiteTemplate
|
||||
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
operator := "AND"
|
||||
if v.CallExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.convertWithTemplates(ctx, v.CallExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := v.CallExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
// Handle size(tags) comparison
|
||||
if len(leftCallExpr.CallExpr.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(leftCallExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?",
|
||||
filter.GetSQL("json_array_length", dbType), operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetExprValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
operator := d.getComparisonOperator(v.CallExpr.Function)
|
||||
|
||||
if identifier == "created_ts" || identifier == "updated_ts" {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampSQL := fmt.Sprintf(filter.GetSQL("timestamp_field", dbType), identifier)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", timestampSQL, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "visibility" || identifier == "content" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
var sqlTemplate string
|
||||
if identifier == "visibility" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`visibility`"
|
||||
} else if identifier == "content" {
|
||||
sqlTemplate = filter.GetSQL("table_prefix", dbType) + ".`content`"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
} else if identifier == "creator_id" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
sqlTemplate := filter.GetSQL("table_prefix", dbType) + ".`creator_id`"
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", sqlTemplate, operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
} else if identifier == "has_task_list" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
// Use template for boolean comparison
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_false", dbType)
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_true", dbType)
|
||||
} else {
|
||||
sqlTemplate = filter.GetSQL("boolean_not_false", dbType)
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_link" || identifier == "has_code" || identifier == "has_incomplete_tasks" {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
|
||||
}
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.Errorf("invalid boolean value for %s", identifier)
|
||||
}
|
||||
|
||||
// Map identifier to JSON path
|
||||
var jsonPath string
|
||||
switch identifier {
|
||||
case "has_link":
|
||||
jsonPath = "$.property.hasLink"
|
||||
case "has_code":
|
||||
jsonPath = "$.property.hasCode"
|
||||
case "has_incomplete_tasks":
|
||||
jsonPath = "$.property.hasIncompleteTasks"
|
||||
}
|
||||
|
||||
// Use JSON_EXTRACT for boolean comparison like has_task_list
|
||||
var sqlTemplate string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE", jsonPath)
|
||||
} else {
|
||||
sqlTemplate = fmt.Sprintf("NOT(JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE)", jsonPath)
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlTemplate = fmt.Sprintf("NOT(JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE)", jsonPath)
|
||||
} else {
|
||||
sqlTemplate = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '%s') IS TRUE", jsonPath)
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlTemplate); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := filter.GetIdentExprName(v.CallExpr.Args[1]); err == nil {
|
||||
// This is "element in collection" - the second argument is the collection
|
||||
if !slices.Contains([]string{"tags"}, identifier) {
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", v.CallExpr.Function, identifier)
|
||||
}
|
||||
|
||||
if identifier == "tags" {
|
||||
// Handle "element" in tags
|
||||
element, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("json_contains_element", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, filter.GetParameterValue(dbType, "json_contains_element", element))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if identifier == "tag" {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
for _, v := range values {
|
||||
subconditions = append(subconditions, filter.GetSQL("json_contains_tag", dbType))
|
||||
args = append(args, filter.GetParameterValue(dbType, "json_contains_tag", v))
|
||||
}
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
} else if identifier == "visibility" {
|
||||
placeholders := filter.FormatPlaceholders(dbType, len(values), 1)
|
||||
visibilitySQL := fmt.Sprintf(filter.GetSQL("visibility_in", dbType), strings.Join(placeholders, ","))
|
||||
if _, err := ctx.Buffer.WriteString(visibilitySQL); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("content_like", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
identifier := v.IdentExpr.GetName()
|
||||
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
if identifier == "pinned" {
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("table_prefix", dbType) + ".`pinned` IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
// Handle has_task_list as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString(filter.GetSQL("boolean_check", dbType)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_link" {
|
||||
// Handle has_link as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasLink') IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_code" {
|
||||
// Handle has_code as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasCode') IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_incomplete_tasks" {
|
||||
// Handle has_incomplete_tasks as a standalone boolean identifier
|
||||
if _, err := ctx.Buffer.WriteString("JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*DB) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
|
|
@ -153,11 +153,11 @@ func TestConvertExprToSQL(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
db := &DB{}
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,8 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
|
|||
}
|
||||
convertCtx := filter.NewConvertContext()
|
||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
condition := convertCtx.Buffer.String()
|
||||
|
|
|
|||
|
|
@ -3,10 +3,6 @@ package store
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
// Driver is an interface for store driver.
|
||||
|
|
@ -73,7 +69,4 @@ type Driver interface {
|
|||
UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error)
|
||||
ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error)
|
||||
DeleteReaction(ctx context.Context, delete *DeleteReaction) error
|
||||
|
||||
// Shortcut related methods.
|
||||
ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,12 +47,14 @@ const AddMemoRelationPopover = (props: Props) => {
|
|||
setIsFetching(true);
|
||||
try {
|
||||
const conditions = [];
|
||||
// Extract user ID from user name (format: users/{user_id})
|
||||
const userId = user.name.replace("users/", "");
|
||||
conditions.push(`creator_id == ${userId}`);
|
||||
if (searchText) {
|
||||
conditions.push(`content.contains("${searchText}")`);
|
||||
}
|
||||
const { memos } = await memoServiceClient.listMemos({
|
||||
parent: user.name,
|
||||
filter: conditions.length > 0 ? conditions.join(" && ") : undefined,
|
||||
filter: conditions.join(" && "),
|
||||
pageSize: DEFAULT_LIST_MEMOS_PAGE_SIZE,
|
||||
});
|
||||
setFetchedMemos(memos);
|
||||
|
|
|
|||
|
|
@ -47,11 +47,20 @@ const PagedMemoList = observer((props: Props) => {
|
|||
setIsRequesting(true);
|
||||
|
||||
try {
|
||||
const filters = [];
|
||||
if (props.owner) {
|
||||
// Extract user ID from owner name (format: users/{user_id})
|
||||
const userId = props.owner.replace("users/", "");
|
||||
filters.push(`creator_id == ${userId}`);
|
||||
}
|
||||
if (props.filter) {
|
||||
filters.push(props.filter);
|
||||
}
|
||||
|
||||
const response = await memoStore.fetchMemos({
|
||||
parent: props.owner || "",
|
||||
state: props.state || State.NORMAL,
|
||||
orderBy: props.orderBy || "display_time desc",
|
||||
filter: props.filter || "",
|
||||
filter: filters.length > 0 ? filters.join(" && ") : undefined,
|
||||
pageSize: props.pageSize || DEFAULT_LIST_MEMOS_PAGE_SIZE,
|
||||
pageToken,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import MemoView from "@/components/MemoView";
|
|||
import PagedMemoList from "@/components/PagedMemoList";
|
||||
import useCurrentUser from "@/hooks/useCurrentUser";
|
||||
import { viewStore, userStore, workspaceStore } from "@/store";
|
||||
import { extractUserIdFromName } from "@/store/common";
|
||||
import memoFilterStore from "@/store/memoFilter";
|
||||
import { State } from "@/types/proto/api/v1/common";
|
||||
import { Memo } from "@/types/proto/api/v1/memo_service";
|
||||
|
|
@ -22,7 +23,7 @@ const Home = observer(() => {
|
|||
const selectedShortcut = userStore.state.shortcuts.find((shortcut) => getShortcutId(shortcut.name) === memoFilterStore.shortcut);
|
||||
|
||||
const memoFilter = useMemo(() => {
|
||||
const conditions = [];
|
||||
const conditions = [`creator_id == "${extractUserIdFromName(user.name)}"`];
|
||||
if (selectedShortcut?.filter) {
|
||||
conditions.push(selectedShortcut.filter);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import UserAvatar from "@/components/UserAvatar";
|
|||
import { Button } from "@/components/ui/button";
|
||||
import useLoading from "@/hooks/useLoading";
|
||||
import { viewStore, userStore } from "@/store";
|
||||
import { extractUserIdFromName } from "@/store/common";
|
||||
import memoFilterStore from "@/store/memoFilter";
|
||||
import { State } from "@/types/proto/api/v1/common";
|
||||
import { Memo } from "@/types/proto/api/v1/memo_service";
|
||||
|
|
@ -46,7 +47,7 @@ const UserProfile = observer(() => {
|
|||
return undefined;
|
||||
}
|
||||
|
||||
const conditions = [];
|
||||
const conditions = [`creator_id == "${extractUserIdFromName(user.name)}"`];
|
||||
for (const filter of memoFilterStore.filters) {
|
||||
if (filter.factor === "contentSearch") {
|
||||
conditions.push(`content.contains("${filter.value}")`);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,10 @@ export const memoNamePrefix = "memos/";
|
|||
export const identityProviderNamePrefix = "identityProviders/";
|
||||
export const activityNamePrefix = "activities/";
|
||||
|
||||
export const extractUserIdFromName = (name: string) => {
|
||||
return name.split(userNamePrefix).pop() || "";
|
||||
};
|
||||
|
||||
export const extractMemoIdFromName = (name: string) => {
|
||||
return name.split(memoNamePrefix).pop() || "";
|
||||
};
|
||||
|
|
|
|||
|
|
@ -175,12 +175,6 @@ export interface CreateMemoRequest {
|
|||
}
|
||||
|
||||
export interface ListMemosRequest {
|
||||
/**
|
||||
* Optional. The parent is the owner of the memos.
|
||||
* If not specified or `users/-`, it will list all memos.
|
||||
* Format: users/{user}
|
||||
*/
|
||||
parent: string;
|
||||
/**
|
||||
* Optional. The maximum number of memos to return.
|
||||
* The service may return fewer than this value.
|
||||
|
|
@ -1090,30 +1084,19 @@ export const CreateMemoRequest: MessageFns<CreateMemoRequest> = {
|
|||
};
|
||||
|
||||
function createBaseListMemosRequest(): ListMemosRequest {
|
||||
return {
|
||||
parent: "",
|
||||
pageSize: 0,
|
||||
pageToken: "",
|
||||
state: State.STATE_UNSPECIFIED,
|
||||
orderBy: "",
|
||||
filter: "",
|
||||
showDeleted: false,
|
||||
};
|
||||
return { pageSize: 0, pageToken: "", state: State.STATE_UNSPECIFIED, orderBy: "", filter: "", showDeleted: false };
|
||||
}
|
||||
|
||||
export const ListMemosRequest: MessageFns<ListMemosRequest> = {
|
||||
encode(message: ListMemosRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter {
|
||||
if (message.parent !== "") {
|
||||
writer.uint32(10).string(message.parent);
|
||||
}
|
||||
if (message.pageSize !== 0) {
|
||||
writer.uint32(16).int32(message.pageSize);
|
||||
writer.uint32(8).int32(message.pageSize);
|
||||
}
|
||||
if (message.pageToken !== "") {
|
||||
writer.uint32(26).string(message.pageToken);
|
||||
writer.uint32(18).string(message.pageToken);
|
||||
}
|
||||
if (message.state !== State.STATE_UNSPECIFIED) {
|
||||
writer.uint32(32).int32(stateToNumber(message.state));
|
||||
writer.uint32(24).int32(stateToNumber(message.state));
|
||||
}
|
||||
if (message.orderBy !== "") {
|
||||
writer.uint32(42).string(message.orderBy);
|
||||
|
|
@ -1135,31 +1118,23 @@ export const ListMemosRequest: MessageFns<ListMemosRequest> = {
|
|||
const tag = reader.uint32();
|
||||
switch (tag >>> 3) {
|
||||
case 1: {
|
||||
if (tag !== 10) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.parent = reader.string();
|
||||
continue;
|
||||
}
|
||||
case 2: {
|
||||
if (tag !== 16) {
|
||||
if (tag !== 8) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.pageSize = reader.int32();
|
||||
continue;
|
||||
}
|
||||
case 3: {
|
||||
if (tag !== 26) {
|
||||
case 2: {
|
||||
if (tag !== 18) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.pageToken = reader.string();
|
||||
continue;
|
||||
}
|
||||
case 4: {
|
||||
if (tag !== 32) {
|
||||
case 3: {
|
||||
if (tag !== 24) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -1204,7 +1179,6 @@ export const ListMemosRequest: MessageFns<ListMemosRequest> = {
|
|||
},
|
||||
fromPartial(object: DeepPartial<ListMemosRequest>): ListMemosRequest {
|
||||
const message = createBaseListMemosRequest();
|
||||
message.parent = object.parent ?? "";
|
||||
message.pageSize = object.pageSize ?? 0;
|
||||
message.pageToken = object.pageToken ?? "";
|
||||
message.state = object.state ?? State.STATE_UNSPECIFIED;
|
||||
|
|
@ -2662,61 +2636,8 @@ export const MemoServiceDefinition = {
|
|||
responseStream: false,
|
||||
options: {
|
||||
_unknownFields: {
|
||||
8410: [new Uint8Array([0]), new Uint8Array([6, 112, 97, 114, 101, 110, 116])],
|
||||
578365826: [
|
||||
new Uint8Array([
|
||||
49,
|
||||
90,
|
||||
32,
|
||||
18,
|
||||
30,
|
||||
47,
|
||||
97,
|
||||
112,
|
||||
105,
|
||||
47,
|
||||
118,
|
||||
49,
|
||||
47,
|
||||
123,
|
||||
112,
|
||||
97,
|
||||
114,
|
||||
101,
|
||||
110,
|
||||
116,
|
||||
61,
|
||||
117,
|
||||
115,
|
||||
101,
|
||||
114,
|
||||
115,
|
||||
47,
|
||||
42,
|
||||
125,
|
||||
47,
|
||||
109,
|
||||
101,
|
||||
109,
|
||||
111,
|
||||
115,
|
||||
18,
|
||||
13,
|
||||
47,
|
||||
97,
|
||||
112,
|
||||
105,
|
||||
47,
|
||||
118,
|
||||
49,
|
||||
47,
|
||||
109,
|
||||
101,
|
||||
109,
|
||||
111,
|
||||
115,
|
||||
]),
|
||||
],
|
||||
8410: [new Uint8Array([0])],
|
||||
578365826: [new Uint8Array([15, 18, 13, 47, 97, 112, 105, 47, 118, 49, 47, 109, 101, 109, 111, 115])],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -128,6 +128,52 @@ export function editionToNumber(object: Edition): number {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Describes the 'visibility' of a symbol with respect to the proto import
|
||||
* system. Symbols can only be imported when the visibility rules do not prevent
|
||||
* it (ex: local symbols cannot be imported). Visibility modifiers can only set
|
||||
* on `message` and `enum` as they are the only types available to be referenced
|
||||
* from other files.
|
||||
*/
|
||||
export enum SymbolVisibility {
|
||||
VISIBILITY_UNSET = "VISIBILITY_UNSET",
|
||||
VISIBILITY_LOCAL = "VISIBILITY_LOCAL",
|
||||
VISIBILITY_EXPORT = "VISIBILITY_EXPORT",
|
||||
UNRECOGNIZED = "UNRECOGNIZED",
|
||||
}
|
||||
|
||||
export function symbolVisibilityFromJSON(object: any): SymbolVisibility {
|
||||
switch (object) {
|
||||
case 0:
|
||||
case "VISIBILITY_UNSET":
|
||||
return SymbolVisibility.VISIBILITY_UNSET;
|
||||
case 1:
|
||||
case "VISIBILITY_LOCAL":
|
||||
return SymbolVisibility.VISIBILITY_LOCAL;
|
||||
case 2:
|
||||
case "VISIBILITY_EXPORT":
|
||||
return SymbolVisibility.VISIBILITY_EXPORT;
|
||||
case -1:
|
||||
case "UNRECOGNIZED":
|
||||
default:
|
||||
return SymbolVisibility.UNRECOGNIZED;
|
||||
}
|
||||
}
|
||||
|
||||
export function symbolVisibilityToNumber(object: SymbolVisibility): number {
|
||||
switch (object) {
|
||||
case SymbolVisibility.VISIBILITY_UNSET:
|
||||
return 0;
|
||||
case SymbolVisibility.VISIBILITY_LOCAL:
|
||||
return 1;
|
||||
case SymbolVisibility.VISIBILITY_EXPORT:
|
||||
return 2;
|
||||
case SymbolVisibility.UNRECOGNIZED:
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The protocol compiler can output a FileDescriptorSet containing the .proto
|
||||
* files it parses.
|
||||
|
|
@ -155,6 +201,11 @@ export interface FileDescriptorProto {
|
|||
* For Google-internal migration only. Do not use.
|
||||
*/
|
||||
weakDependency: number[];
|
||||
/**
|
||||
* Names of files imported by this file purely for the purpose of providing
|
||||
* option extensions. These are excluded from the dependency list above.
|
||||
*/
|
||||
optionDependency: string[];
|
||||
/** All top-level definitions in this file. */
|
||||
messageType: DescriptorProto[];
|
||||
enumType: EnumDescriptorProto[];
|
||||
|
|
@ -209,6 +260,8 @@ export interface DescriptorProto {
|
|||
* A given name may only be reserved once.
|
||||
*/
|
||||
reservedName: string[];
|
||||
/** Support for `export` and `local` keywords on enums. */
|
||||
visibility?: SymbolVisibility | undefined;
|
||||
}
|
||||
|
||||
export interface DescriptorProto_ExtensionRange {
|
||||
|
|
@ -632,6 +685,8 @@ export interface EnumDescriptorProto {
|
|||
* be reserved once.
|
||||
*/
|
||||
reservedName: string[];
|
||||
/** Support for `export` and `local` keywords on enums. */
|
||||
visibility?: SymbolVisibility | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -1594,6 +1649,7 @@ export interface FeatureSet {
|
|||
messageEncoding?: FeatureSet_MessageEncoding | undefined;
|
||||
jsonFormat?: FeatureSet_JsonFormat | undefined;
|
||||
enforceNamingStyle?: FeatureSet_EnforceNamingStyle | undefined;
|
||||
defaultSymbolVisibility?: FeatureSet_VisibilityFeature_DefaultSymbolVisibility | undefined;
|
||||
}
|
||||
|
||||
export enum FeatureSet_FieldPresence {
|
||||
|
|
@ -1875,6 +1931,72 @@ export function featureSet_EnforceNamingStyleToNumber(object: FeatureSet_Enforce
|
|||
}
|
||||
}
|
||||
|
||||
export interface FeatureSet_VisibilityFeature {
|
||||
}
|
||||
|
||||
export enum FeatureSet_VisibilityFeature_DefaultSymbolVisibility {
|
||||
DEFAULT_SYMBOL_VISIBILITY_UNKNOWN = "DEFAULT_SYMBOL_VISIBILITY_UNKNOWN",
|
||||
/** EXPORT_ALL - Default pre-EDITION_2024, all UNSET visibility are export. */
|
||||
EXPORT_ALL = "EXPORT_ALL",
|
||||
/** EXPORT_TOP_LEVEL - All top-level symbols default to export, nested default to local. */
|
||||
EXPORT_TOP_LEVEL = "EXPORT_TOP_LEVEL",
|
||||
/** LOCAL_ALL - All symbols default to local. */
|
||||
LOCAL_ALL = "LOCAL_ALL",
|
||||
/**
|
||||
* STRICT - All symbols local by default. Nested types cannot be exported.
|
||||
* With special case caveat for message { enum {} reserved 1 to max; }
|
||||
* This is the recommended setting for new protos.
|
||||
*/
|
||||
STRICT = "STRICT",
|
||||
UNRECOGNIZED = "UNRECOGNIZED",
|
||||
}
|
||||
|
||||
export function featureSet_VisibilityFeature_DefaultSymbolVisibilityFromJSON(
|
||||
object: any,
|
||||
): FeatureSet_VisibilityFeature_DefaultSymbolVisibility {
|
||||
switch (object) {
|
||||
case 0:
|
||||
case "DEFAULT_SYMBOL_VISIBILITY_UNKNOWN":
|
||||
return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN;
|
||||
case 1:
|
||||
case "EXPORT_ALL":
|
||||
return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_ALL;
|
||||
case 2:
|
||||
case "EXPORT_TOP_LEVEL":
|
||||
return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_TOP_LEVEL;
|
||||
case 3:
|
||||
case "LOCAL_ALL":
|
||||
return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.LOCAL_ALL;
|
||||
case 4:
|
||||
case "STRICT":
|
||||
return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.STRICT;
|
||||
case -1:
|
||||
case "UNRECOGNIZED":
|
||||
default:
|
||||
return FeatureSet_VisibilityFeature_DefaultSymbolVisibility.UNRECOGNIZED;
|
||||
}
|
||||
}
|
||||
|
||||
export function featureSet_VisibilityFeature_DefaultSymbolVisibilityToNumber(
|
||||
object: FeatureSet_VisibilityFeature_DefaultSymbolVisibility,
|
||||
): number {
|
||||
switch (object) {
|
||||
case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN:
|
||||
return 0;
|
||||
case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_ALL:
|
||||
return 1;
|
||||
case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.EXPORT_TOP_LEVEL:
|
||||
return 2;
|
||||
case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.LOCAL_ALL:
|
||||
return 3;
|
||||
case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.STRICT:
|
||||
return 4;
|
||||
case FeatureSet_VisibilityFeature_DefaultSymbolVisibility.UNRECOGNIZED:
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A compiled specification for the defaults of a set of features. These
|
||||
* messages are generated from FeatureSet extensions and can be used to seed
|
||||
|
|
@ -2195,6 +2317,7 @@ function createBaseFileDescriptorProto(): FileDescriptorProto {
|
|||
dependency: [],
|
||||
publicDependency: [],
|
||||
weakDependency: [],
|
||||
optionDependency: [],
|
||||
messageType: [],
|
||||
enumType: [],
|
||||
service: [],
|
||||
|
|
@ -2227,6 +2350,9 @@ export const FileDescriptorProto: MessageFns<FileDescriptorProto> = {
|
|||
writer.int32(v);
|
||||
}
|
||||
writer.join();
|
||||
for (const v of message.optionDependency) {
|
||||
writer.uint32(122).string(v!);
|
||||
}
|
||||
for (const v of message.messageType) {
|
||||
DescriptorProto.encode(v!, writer.uint32(34).fork()).join();
|
||||
}
|
||||
|
|
@ -2321,6 +2447,14 @@ export const FileDescriptorProto: MessageFns<FileDescriptorProto> = {
|
|||
|
||||
break;
|
||||
}
|
||||
case 15: {
|
||||
if (tag !== 122) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.optionDependency.push(reader.string());
|
||||
continue;
|
||||
}
|
||||
case 4: {
|
||||
if (tag !== 34) {
|
||||
break;
|
||||
|
|
@ -2404,6 +2538,7 @@ export const FileDescriptorProto: MessageFns<FileDescriptorProto> = {
|
|||
message.dependency = object.dependency?.map((e) => e) || [];
|
||||
message.publicDependency = object.publicDependency?.map((e) => e) || [];
|
||||
message.weakDependency = object.weakDependency?.map((e) => e) || [];
|
||||
message.optionDependency = object.optionDependency?.map((e) => e) || [];
|
||||
message.messageType = object.messageType?.map((e) => DescriptorProto.fromPartial(e)) || [];
|
||||
message.enumType = object.enumType?.map((e) => EnumDescriptorProto.fromPartial(e)) || [];
|
||||
message.service = object.service?.map((e) => ServiceDescriptorProto.fromPartial(e)) || [];
|
||||
|
|
@ -2432,6 +2567,7 @@ function createBaseDescriptorProto(): DescriptorProto {
|
|||
options: undefined,
|
||||
reservedRange: [],
|
||||
reservedName: [],
|
||||
visibility: SymbolVisibility.VISIBILITY_UNSET,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -2467,6 +2603,9 @@ export const DescriptorProto: MessageFns<DescriptorProto> = {
|
|||
for (const v of message.reservedName) {
|
||||
writer.uint32(82).string(v!);
|
||||
}
|
||||
if (message.visibility !== undefined && message.visibility !== SymbolVisibility.VISIBILITY_UNSET) {
|
||||
writer.uint32(88).int32(symbolVisibilityToNumber(message.visibility));
|
||||
}
|
||||
return writer;
|
||||
},
|
||||
|
||||
|
|
@ -2557,6 +2696,14 @@ export const DescriptorProto: MessageFns<DescriptorProto> = {
|
|||
message.reservedName.push(reader.string());
|
||||
continue;
|
||||
}
|
||||
case 11: {
|
||||
if (tag !== 88) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.visibility = symbolVisibilityFromJSON(reader.int32());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if ((tag & 7) === 4 || tag === 0) {
|
||||
break;
|
||||
|
|
@ -2583,6 +2730,7 @@ export const DescriptorProto: MessageFns<DescriptorProto> = {
|
|||
: undefined;
|
||||
message.reservedRange = object.reservedRange?.map((e) => DescriptorProto_ReservedRange.fromPartial(e)) || [];
|
||||
message.reservedName = object.reservedName?.map((e) => e) || [];
|
||||
message.visibility = object.visibility ?? SymbolVisibility.VISIBILITY_UNSET;
|
||||
return message;
|
||||
},
|
||||
};
|
||||
|
|
@ -3143,7 +3291,14 @@ export const OneofDescriptorProto: MessageFns<OneofDescriptorProto> = {
|
|||
};
|
||||
|
||||
function createBaseEnumDescriptorProto(): EnumDescriptorProto {
|
||||
return { name: "", value: [], options: undefined, reservedRange: [], reservedName: [] };
|
||||
return {
|
||||
name: "",
|
||||
value: [],
|
||||
options: undefined,
|
||||
reservedRange: [],
|
||||
reservedName: [],
|
||||
visibility: SymbolVisibility.VISIBILITY_UNSET,
|
||||
};
|
||||
}
|
||||
|
||||
export const EnumDescriptorProto: MessageFns<EnumDescriptorProto> = {
|
||||
|
|
@ -3163,6 +3318,9 @@ export const EnumDescriptorProto: MessageFns<EnumDescriptorProto> = {
|
|||
for (const v of message.reservedName) {
|
||||
writer.uint32(42).string(v!);
|
||||
}
|
||||
if (message.visibility !== undefined && message.visibility !== SymbolVisibility.VISIBILITY_UNSET) {
|
||||
writer.uint32(48).int32(symbolVisibilityToNumber(message.visibility));
|
||||
}
|
||||
return writer;
|
||||
},
|
||||
|
||||
|
|
@ -3213,6 +3371,14 @@ export const EnumDescriptorProto: MessageFns<EnumDescriptorProto> = {
|
|||
message.reservedName.push(reader.string());
|
||||
continue;
|
||||
}
|
||||
case 6: {
|
||||
if (tag !== 48) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.visibility = symbolVisibilityFromJSON(reader.int32());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if ((tag & 7) === 4 || tag === 0) {
|
||||
break;
|
||||
|
|
@ -3235,6 +3401,7 @@ export const EnumDescriptorProto: MessageFns<EnumDescriptorProto> = {
|
|||
message.reservedRange = object.reservedRange?.map((e) => EnumDescriptorProto_EnumReservedRange.fromPartial(e)) ||
|
||||
[];
|
||||
message.reservedName = object.reservedName?.map((e) => e) || [];
|
||||
message.visibility = object.visibility ?? SymbolVisibility.VISIBILITY_UNSET;
|
||||
return message;
|
||||
},
|
||||
};
|
||||
|
|
@ -4999,6 +5166,7 @@ function createBaseFeatureSet(): FeatureSet {
|
|||
messageEncoding: FeatureSet_MessageEncoding.MESSAGE_ENCODING_UNKNOWN,
|
||||
jsonFormat: FeatureSet_JsonFormat.JSON_FORMAT_UNKNOWN,
|
||||
enforceNamingStyle: FeatureSet_EnforceNamingStyle.ENFORCE_NAMING_STYLE_UNKNOWN,
|
||||
defaultSymbolVisibility: FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -5039,6 +5207,15 @@ export const FeatureSet: MessageFns<FeatureSet> = {
|
|||
) {
|
||||
writer.uint32(56).int32(featureSet_EnforceNamingStyleToNumber(message.enforceNamingStyle));
|
||||
}
|
||||
if (
|
||||
message.defaultSymbolVisibility !== undefined &&
|
||||
message.defaultSymbolVisibility !==
|
||||
FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN
|
||||
) {
|
||||
writer.uint32(64).int32(
|
||||
featureSet_VisibilityFeature_DefaultSymbolVisibilityToNumber(message.defaultSymbolVisibility),
|
||||
);
|
||||
}
|
||||
return writer;
|
||||
},
|
||||
|
||||
|
|
@ -5105,6 +5282,16 @@ export const FeatureSet: MessageFns<FeatureSet> = {
|
|||
message.enforceNamingStyle = featureSet_EnforceNamingStyleFromJSON(reader.int32());
|
||||
continue;
|
||||
}
|
||||
case 8: {
|
||||
if (tag !== 64) {
|
||||
break;
|
||||
}
|
||||
|
||||
message.defaultSymbolVisibility = featureSet_VisibilityFeature_DefaultSymbolVisibilityFromJSON(
|
||||
reader.int32(),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if ((tag & 7) === 4 || tag === 0) {
|
||||
break;
|
||||
|
|
@ -5128,6 +5315,42 @@ export const FeatureSet: MessageFns<FeatureSet> = {
|
|||
message.jsonFormat = object.jsonFormat ?? FeatureSet_JsonFormat.JSON_FORMAT_UNKNOWN;
|
||||
message.enforceNamingStyle = object.enforceNamingStyle ??
|
||||
FeatureSet_EnforceNamingStyle.ENFORCE_NAMING_STYLE_UNKNOWN;
|
||||
message.defaultSymbolVisibility = object.defaultSymbolVisibility ??
|
||||
FeatureSet_VisibilityFeature_DefaultSymbolVisibility.DEFAULT_SYMBOL_VISIBILITY_UNKNOWN;
|
||||
return message;
|
||||
},
|
||||
};
|
||||
|
||||
function createBaseFeatureSet_VisibilityFeature(): FeatureSet_VisibilityFeature {
|
||||
return {};
|
||||
}
|
||||
|
||||
export const FeatureSet_VisibilityFeature: MessageFns<FeatureSet_VisibilityFeature> = {
|
||||
encode(_: FeatureSet_VisibilityFeature, writer: BinaryWriter = new BinaryWriter()): BinaryWriter {
|
||||
return writer;
|
||||
},
|
||||
|
||||
decode(input: BinaryReader | Uint8Array, length?: number): FeatureSet_VisibilityFeature {
|
||||
const reader = input instanceof BinaryReader ? input : new BinaryReader(input);
|
||||
let end = length === undefined ? reader.len : reader.pos + length;
|
||||
const message = createBaseFeatureSet_VisibilityFeature();
|
||||
while (reader.pos < end) {
|
||||
const tag = reader.uint32();
|
||||
switch (tag >>> 3) {
|
||||
}
|
||||
if ((tag & 7) === 4 || tag === 0) {
|
||||
break;
|
||||
}
|
||||
reader.skip(tag & 7);
|
||||
}
|
||||
return message;
|
||||
},
|
||||
|
||||
create(base?: DeepPartial<FeatureSet_VisibilityFeature>): FeatureSet_VisibilityFeature {
|
||||
return FeatureSet_VisibilityFeature.fromPartial(base ?? {});
|
||||
},
|
||||
fromPartial(_: DeepPartial<FeatureSet_VisibilityFeature>): FeatureSet_VisibilityFeature {
|
||||
const message = createBaseFeatureSet_VisibilityFeature();
|
||||
return message;
|
||||
},
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue