mirror of https://github.com/usememos/memos.git
513 lines
15 KiB
Go
513 lines
15 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/pkg/errors"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
|
)
|
|
|
|
// UserSetting is the raw, storage-level form of a user setting: the owning
// user, the setting key, and the setting payload as a string. In this file the
// payload is produced by protojson.Marshal (convertUserSettingToRaw) and parsed
// back by convertUserSettingFromRaw.
type UserSetting struct {
	UserID int32                   // owning user
	Key    storepb.UserSetting_Key // which setting this is; selects how Value is decoded
	Value  string                  // protojson-encoded setting message
}
|
|
|
|
// FindUserSetting filters user-setting lookups. GetUserSetting consults the
// in-memory cache only when UserID is set; otherwise the query goes straight
// to the driver.
type FindUserSetting struct {
	UserID *int32                  // optional owning user; nil leaves the user unconstrained (TODO confirm in drivers)
	Key    storepb.UserSetting_Key // setting key to look up; zero-value handling is driver-defined (TODO confirm)
}
|
|
|
|
// RefreshTokenQueryResult contains the result of querying a refresh token:
// the owning user's ID and the matching token entry.
// NOTE(review): not referenced in this file; presumably populated by a driver
// lookup — confirm against the driver implementations.
type RefreshTokenQueryResult struct {
	UserID       int32
	RefreshToken *storepb.RefreshTokensUserSetting_RefreshToken
}
|
|
|
|
// PATQueryResult contains the result of querying a PAT by hash.
// The driver lookup supplies UserID (and presumably PAT); Store.GetUserByPATHash
// then resolves the owning user and stores it in User.
type PATQueryResult struct {
	UserID int32
	User   *User // filled in by Store.GetUserByPATHash after the driver lookup
	PAT    *storepb.PersonalAccessTokensUserSetting_PersonalAccessToken
}
|
|
|
|
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
|
userSettingRaw, err := convertUserSettingToRaw(upsert)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userSettingRaw, err = s.driver.UpsertUserSetting(ctx, userSettingRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userSetting, err := convertUserSettingFromRaw(userSettingRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if userSetting == nil {
|
|
return nil, errors.New("unexpected nil user setting")
|
|
}
|
|
s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
|
|
return userSetting, nil
|
|
}
|
|
|
|
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) {
|
|
userSettingRawList, err := s.driver.ListUserSettings(ctx, find)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userSettings := []*storepb.UserSetting{}
|
|
for _, userSettingRaw := range userSettingRawList {
|
|
userSetting, err := convertUserSettingFromRaw(userSettingRaw)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if userSetting == nil {
|
|
continue
|
|
}
|
|
s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
|
|
userSettings = append(userSettings, userSetting)
|
|
}
|
|
return userSettings, nil
|
|
}
|
|
|
|
func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*storepb.UserSetting, error) {
|
|
if find.UserID != nil {
|
|
if cache, ok := s.userSettingCache.Get(ctx, getUserSettingCacheKey(*find.UserID, find.Key.String())); ok {
|
|
userSetting, ok := cache.(*storepb.UserSetting)
|
|
if ok {
|
|
return userSetting, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
list, err := s.ListUserSettings(ctx, find)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(list) == 0 {
|
|
return nil, nil
|
|
}
|
|
if len(list) > 1 {
|
|
return nil, errors.Errorf("expected 1 user setting, but got %d", len(list))
|
|
}
|
|
|
|
userSetting := list[0]
|
|
s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
|
|
return userSetting, nil
|
|
}
|
|
|
|
// GetUserByPATHash finds a user by PAT hash.
|
|
func (s *Store) GetUserByPATHash(ctx context.Context, tokenHash string) (*PATQueryResult, error) {
|
|
result, err := s.driver.GetUserByPATHash(ctx, tokenHash)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Fetch user info
|
|
user, err := s.GetUser(ctx, &FindUser{ID: &result.UserID})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user == nil {
|
|
return nil, errors.New("user not found for PAT")
|
|
}
|
|
result.User = user
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// GetUserRefreshTokens returns the refresh tokens of the user.
|
|
func (s *Store) GetUserRefreshTokens(ctx context.Context, userID int32) ([]*storepb.RefreshTokensUserSetting_RefreshToken, error) {
|
|
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{
|
|
UserID: &userID,
|
|
Key: storepb.UserSetting_REFRESH_TOKENS,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if userSetting == nil {
|
|
return []*storepb.RefreshTokensUserSetting_RefreshToken{}, nil
|
|
}
|
|
return userSetting.GetRefreshTokens().RefreshTokens, nil
|
|
}
|
|
|
|
// AddUserRefreshToken adds a new refresh token for the user.
|
|
func (s *Store) AddUserRefreshToken(ctx context.Context, userID int32, token *storepb.RefreshTokensUserSetting_RefreshToken) error {
|
|
tokens, err := s.GetUserRefreshTokens(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tokens = append(tokens, token)
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_REFRESH_TOKENS,
|
|
Value: &storepb.UserSetting_RefreshTokens{
|
|
RefreshTokens: &storepb.RefreshTokensUserSetting{
|
|
RefreshTokens: tokens,
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
// RemoveUserRefreshToken removes a refresh token from the user.
|
|
func (s *Store) RemoveUserRefreshToken(ctx context.Context, userID int32, tokenID string) error {
|
|
existingTokens, err := s.GetUserRefreshTokens(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
newTokens := make([]*storepb.RefreshTokensUserSetting_RefreshToken, 0, len(existingTokens))
|
|
for _, token := range existingTokens {
|
|
if token.TokenId != tokenID {
|
|
newTokens = append(newTokens, token)
|
|
}
|
|
}
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_REFRESH_TOKENS,
|
|
Value: &storepb.UserSetting_RefreshTokens{
|
|
RefreshTokens: &storepb.RefreshTokensUserSetting{
|
|
RefreshTokens: newTokens,
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
// GetUserRefreshTokenByID returns a specific refresh token.
|
|
func (s *Store) GetUserRefreshTokenByID(ctx context.Context, userID int32, tokenID string) (*storepb.RefreshTokensUserSetting_RefreshToken, error) {
|
|
tokens, err := s.GetUserRefreshTokens(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, token := range tokens {
|
|
if token.TokenId == tokenID {
|
|
return token, nil
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// GetUserPersonalAccessTokens returns the PATs of the user.
|
|
func (s *Store) GetUserPersonalAccessTokens(ctx context.Context, userID int32) ([]*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken, error) {
|
|
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{
|
|
UserID: &userID,
|
|
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if userSetting == nil {
|
|
return []*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{}, nil
|
|
}
|
|
return userSetting.GetPersonalAccessTokens().Tokens, nil
|
|
}
|
|
|
|
// AddUserPersonalAccessToken adds a new PAT for the user.
|
|
func (s *Store) AddUserPersonalAccessToken(ctx context.Context, userID int32, token *storepb.PersonalAccessTokensUserSetting_PersonalAccessToken) error {
|
|
tokens, err := s.GetUserPersonalAccessTokens(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tokens = append(tokens, token)
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
|
|
Value: &storepb.UserSetting_PersonalAccessTokens{
|
|
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
|
|
Tokens: tokens,
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
// RemoveUserPersonalAccessToken removes a PAT from the user.
|
|
func (s *Store) RemoveUserPersonalAccessToken(ctx context.Context, userID int32, tokenID string) error {
|
|
existingTokens, err := s.GetUserPersonalAccessTokens(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
newTokens := make([]*storepb.PersonalAccessTokensUserSetting_PersonalAccessToken, 0, len(existingTokens))
|
|
for _, token := range existingTokens {
|
|
if token.TokenId != tokenID {
|
|
newTokens = append(newTokens, token)
|
|
}
|
|
}
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
|
|
Value: &storepb.UserSetting_PersonalAccessTokens{
|
|
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
|
|
Tokens: newTokens,
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
// UpdatePATLastUsed updates the last_used_at timestamp of a PAT.
|
|
func (s *Store) UpdatePATLastUsed(ctx context.Context, userID int32, tokenID string, lastUsed *timestamppb.Timestamp) error {
|
|
tokens, err := s.GetUserPersonalAccessTokens(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, token := range tokens {
|
|
if token.TokenId == tokenID {
|
|
token.LastUsedAt = lastUsed
|
|
break
|
|
}
|
|
}
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_PERSONAL_ACCESS_TOKENS,
|
|
Value: &storepb.UserSetting_PersonalAccessTokens{
|
|
PersonalAccessTokens: &storepb.PersonalAccessTokensUserSetting{
|
|
Tokens: tokens,
|
|
},
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
// GetUserWebhooks returns the webhooks of the user.
|
|
func (s *Store) GetUserWebhooks(ctx context.Context, userID int32) ([]*storepb.WebhooksUserSetting_Webhook, error) {
|
|
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{
|
|
UserID: &userID,
|
|
Key: storepb.UserSetting_WEBHOOKS,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if userSetting == nil {
|
|
return []*storepb.WebhooksUserSetting_Webhook{}, nil
|
|
}
|
|
|
|
webhooksUserSetting := userSetting.GetWebhooks()
|
|
return webhooksUserSetting.Webhooks, nil
|
|
}
|
|
|
|
// AddUserWebhook adds a new webhook for the user.
|
|
func (s *Store) AddUserWebhook(ctx context.Context, userID int32, webhook *storepb.WebhooksUserSetting_Webhook) error {
|
|
existingWebhooks, err := s.GetUserWebhooks(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check if webhook already exists, update if it does
|
|
var updatedWebhooks []*storepb.WebhooksUserSetting_Webhook
|
|
webhookExists := false
|
|
for _, existing := range existingWebhooks {
|
|
if existing.Id == webhook.Id {
|
|
updatedWebhooks = append(updatedWebhooks, webhook)
|
|
webhookExists = true
|
|
} else {
|
|
updatedWebhooks = append(updatedWebhooks, existing)
|
|
}
|
|
}
|
|
|
|
// If webhook doesn't exist, add it
|
|
if !webhookExists {
|
|
updatedWebhooks = append(updatedWebhooks, webhook)
|
|
}
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_WEBHOOKS,
|
|
Value: &storepb.UserSetting_Webhooks{
|
|
Webhooks: &storepb.WebhooksUserSetting{
|
|
Webhooks: updatedWebhooks,
|
|
},
|
|
},
|
|
})
|
|
|
|
return err
|
|
}
|
|
|
|
// RemoveUserWebhook removes the webhook of the user.
|
|
func (s *Store) RemoveUserWebhook(ctx context.Context, userID int32, webhookID string) error {
|
|
oldWebhooks, err := s.GetUserWebhooks(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
newWebhooks := make([]*storepb.WebhooksUserSetting_Webhook, 0, len(oldWebhooks))
|
|
for _, webhook := range oldWebhooks {
|
|
if webhookID != webhook.Id {
|
|
newWebhooks = append(newWebhooks, webhook)
|
|
}
|
|
}
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_WEBHOOKS,
|
|
Value: &storepb.UserSetting_Webhooks{
|
|
Webhooks: &storepb.WebhooksUserSetting{
|
|
Webhooks: newWebhooks,
|
|
},
|
|
},
|
|
})
|
|
|
|
return err
|
|
}
|
|
|
|
// UpdateUserWebhook updates an existing webhook for the user.
|
|
func (s *Store) UpdateUserWebhook(ctx context.Context, userID int32, webhook *storepb.WebhooksUserSetting_Webhook) error {
|
|
webhooks, err := s.GetUserWebhooks(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i, existing := range webhooks {
|
|
if existing.Id == webhook.Id {
|
|
webhooks[i] = webhook
|
|
break
|
|
}
|
|
}
|
|
|
|
_, err = s.UpsertUserSetting(ctx, &storepb.UserSetting{
|
|
UserId: userID,
|
|
Key: storepb.UserSetting_WEBHOOKS,
|
|
Value: &storepb.UserSetting_Webhooks{
|
|
Webhooks: &storepb.WebhooksUserSetting{
|
|
Webhooks: webhooks,
|
|
},
|
|
},
|
|
})
|
|
|
|
return err
|
|
}
|
|
|
|
func convertUserSettingFromRaw(raw *UserSetting) (*storepb.UserSetting, error) {
|
|
userSetting := &storepb.UserSetting{
|
|
UserId: raw.UserID,
|
|
Key: raw.Key,
|
|
}
|
|
|
|
switch raw.Key {
|
|
case storepb.UserSetting_ACCESS_TOKENS:
|
|
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), accessTokensUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_AccessTokens{AccessTokens: accessTokensUserSetting}
|
|
case storepb.UserSetting_SESSIONS:
|
|
sessionsUserSetting := &storepb.SessionsUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), sessionsUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_Sessions{Sessions: sessionsUserSetting}
|
|
case storepb.UserSetting_SHORTCUTS:
|
|
shortcutsUserSetting := &storepb.ShortcutsUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), shortcutsUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_Shortcuts{Shortcuts: shortcutsUserSetting}
|
|
case storepb.UserSetting_GENERAL:
|
|
generalUserSetting := &storepb.GeneralUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), generalUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_General{General: generalUserSetting}
|
|
case storepb.UserSetting_REFRESH_TOKENS:
|
|
refreshTokensUserSetting := &storepb.RefreshTokensUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), refreshTokensUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_RefreshTokens{RefreshTokens: refreshTokensUserSetting}
|
|
case storepb.UserSetting_PERSONAL_ACCESS_TOKENS:
|
|
patsUserSetting := &storepb.PersonalAccessTokensUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), patsUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_PersonalAccessTokens{PersonalAccessTokens: patsUserSetting}
|
|
case storepb.UserSetting_WEBHOOKS:
|
|
webhooksUserSetting := &storepb.WebhooksUserSetting{}
|
|
if err := protojsonUnmarshaler.Unmarshal([]byte(raw.Value), webhooksUserSetting); err != nil {
|
|
return nil, err
|
|
}
|
|
userSetting.Value = &storepb.UserSetting_Webhooks{Webhooks: webhooksUserSetting}
|
|
default:
|
|
return nil, nil
|
|
}
|
|
return userSetting, nil
|
|
}
|
|
|
|
func convertUserSettingToRaw(userSetting *storepb.UserSetting) (*UserSetting, error) {
|
|
raw := &UserSetting{
|
|
UserID: userSetting.UserId,
|
|
Key: userSetting.Key,
|
|
}
|
|
|
|
switch userSetting.Key {
|
|
case storepb.UserSetting_ACCESS_TOKENS:
|
|
accessTokensUserSetting := userSetting.GetAccessTokens()
|
|
value, err := protojson.Marshal(accessTokensUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
case storepb.UserSetting_SESSIONS:
|
|
sessionsUserSetting := userSetting.GetSessions()
|
|
value, err := protojson.Marshal(sessionsUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
case storepb.UserSetting_SHORTCUTS:
|
|
shortcutsUserSetting := userSetting.GetShortcuts()
|
|
value, err := protojson.Marshal(shortcutsUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
case storepb.UserSetting_GENERAL:
|
|
generalUserSetting := userSetting.GetGeneral()
|
|
value, err := protojson.Marshal(generalUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
case storepb.UserSetting_REFRESH_TOKENS:
|
|
refreshTokensUserSetting := userSetting.GetRefreshTokens()
|
|
value, err := protojson.Marshal(refreshTokensUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
case storepb.UserSetting_PERSONAL_ACCESS_TOKENS:
|
|
patsUserSetting := userSetting.GetPersonalAccessTokens()
|
|
value, err := protojson.Marshal(patsUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
case storepb.UserSetting_WEBHOOKS:
|
|
webhooksUserSetting := userSetting.GetWebhooks()
|
|
value, err := protojson.Marshal(webhooksUserSetting)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
raw.Value = string(value)
|
|
default:
|
|
return nil, errors.Errorf("unsupported user setting key: %v", userSetting.Key)
|
|
}
|
|
return raw, nil
|
|
}
|