init push
This commit is contained in:
329
backend/repo/pg/auth.go
Normal file
329
backend/repo/pg/auth.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/chaitin/panda-wiki/consts"
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/store/cache"
|
||||
"github.com/chaitin/panda-wiki/store/pg"
|
||||
)
|
||||
|
||||
type AuthRepo struct {
|
||||
db *pg.DB
|
||||
logger *log.Logger
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
func NewAuthRepo(db *pg.DB, logger *log.Logger, cache *cache.Cache) *AuthRepo {
|
||||
return &AuthRepo{
|
||||
db: db,
|
||||
logger: logger,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthUserinfoByIDs(ctx context.Context, authIDs []uint) (map[uint]*domain.AuthInfo, error) {
|
||||
if len(authIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var authUserInfo = []domain.AuthInfo{}
|
||||
err := r.db.WithContext(ctx).Table("auths").
|
||||
Select("id,user_info as auth_user_info").
|
||||
Where("id IN (?) ", authIDs).
|
||||
Where("source_type NOT IN (?)", consts.BotSourceTypes).
|
||||
Find(&authUserInfo).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//set map
|
||||
result := make(map[uint]*domain.AuthInfo, 0)
|
||||
for _, a := range authUserInfo {
|
||||
result[a.ID] = &a
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthGroupByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) {
|
||||
authGroups := make([]domain.AuthGroup, 0)
|
||||
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
|
||||
Where("? = ANY(auth_ids)", authID).
|
||||
Find(&authGroups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return authGroups, nil
|
||||
}
|
||||
|
||||
// getAllAuthGroupsAsMap fetches all auth groups and returns them as a map for quick lookup
|
||||
func (r *AuthRepo) getAllAuthGroupsAsMap(ctx context.Context) (map[uint]*domain.AuthGroup, error) {
|
||||
var allGroups []domain.AuthGroup
|
||||
err := r.db.WithContext(ctx).Find(&allGroups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupMap := lo.SliceToMap(allGroups, func(group domain.AuthGroup) (uint, *domain.AuthGroup) {
|
||||
return group.ID, &group
|
||||
})
|
||||
|
||||
return groupMap, nil
|
||||
}
|
||||
|
||||
// getAuthGroupsWithParentsByAuthId is a helper method that retrieves user's auth groups and all parent groups
|
||||
func (r *AuthRepo) getAuthGroupsWithParentsByAuthId(ctx context.Context, authID uint) (map[uint]domain.AuthGroup, error) {
|
||||
// Get user's direct auth groups
|
||||
var directGroups []domain.AuthGroup
|
||||
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
|
||||
Where("? = ANY(auth_ids)", authID).
|
||||
Find(&directGroups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(directGroups) == 0 {
|
||||
return make(map[uint]domain.AuthGroup), nil
|
||||
}
|
||||
|
||||
groupMap, err := r.getAllAuthGroupsAsMap(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resultGroups := make(map[uint]domain.AuthGroup)
|
||||
visited := make(map[uint]bool)
|
||||
|
||||
var findParents func(uint)
|
||||
findParents = func(groupID uint) {
|
||||
if visited[groupID] {
|
||||
return // Avoid circular reference
|
||||
}
|
||||
visited[groupID] = true
|
||||
|
||||
group, exists := groupMap[groupID]
|
||||
if !exists {
|
||||
return // Group not found, end search
|
||||
}
|
||||
|
||||
resultGroups[group.ID] = *group
|
||||
|
||||
if group.ParentID != nil {
|
||||
findParents(*group.ParentID)
|
||||
}
|
||||
}
|
||||
|
||||
// Process user's direct groups and their parent groups
|
||||
for _, group := range directGroups {
|
||||
resultGroups[group.ID] = group
|
||||
if group.ParentID != nil {
|
||||
findParents(*group.ParentID)
|
||||
}
|
||||
}
|
||||
|
||||
return resultGroups, nil
|
||||
}
|
||||
|
||||
// GetAuthGroupWithParentsByAuthId retrieves user's auth groups and all parent groups as slice
|
||||
func (r *AuthRepo) GetAuthGroupWithParentsByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) {
|
||||
groupsMap, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]domain.AuthGroup, 0, len(groupsMap))
|
||||
for _, group := range groupsMap {
|
||||
result = append(result, group)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthGroupIdsByAuthId(ctx context.Context, authID uint) ([]int, error) {
|
||||
groupIds := make([]int, 0)
|
||||
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
|
||||
Where("? = ANY(auth_ids)", authID).
|
||||
Pluck("id", &groupIds).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return groupIds, nil
|
||||
}
|
||||
|
||||
// GetAuthGroupIdsWithParentsByAuthId retrieves user's auth group IDs and all parent group IDs (for permission inheritance)
|
||||
func (r *AuthRepo) GetAuthGroupIdsWithParentsByAuthId(ctx context.Context, authID uint) ([]int, error) {
|
||||
groupsMap, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]int, 0, len(groupsMap))
|
||||
for _, group := range groupsMap {
|
||||
result = append(result, int(group.ID))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthBySourceType(ctx context.Context, sourceType consts.SourceType) (*domain.Auth, error) {
|
||||
var auth *domain.Auth
|
||||
if err := r.db.WithContext(ctx).Model(&domain.Auth{}).Where("source_type = ?", string(sourceType)).First(&auth).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthByKBIDAndSourceType(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.Auth, error) {
|
||||
var auth *domain.Auth
|
||||
if err := r.db.WithContext(ctx).Model(&domain.Auth{}).Where("kb_id = ? AND source_type = ?", kbID, string(sourceType)).First(&auth).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) CreateAuth(ctx context.Context, auth *domain.Auth) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.Auth{}).Create(auth).Error
|
||||
}
|
||||
|
||||
func (r *AuthRepo) DeleteAuth(ctx context.Context, kbID string, authId int64) error {
|
||||
return r.db.WithContext(ctx).Where("kb_id = ? and id = ?", kbID, authId).Delete(&domain.Auth{}).Error
|
||||
}
|
||||
|
||||
func (r *AuthRepo) CreateAuthConfig(ctx context.Context, authConfig *domain.AuthConfig) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var existing domain.AuthConfig
|
||||
err := tx.Model(&domain.AuthConfig{}).
|
||||
Where("kb_id = ?", authConfig.KbID).
|
||||
Where("source_type = ?", authConfig.SourceType).
|
||||
First(&existing).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if err := tx.Model(&domain.AuthConfig{}).
|
||||
Create(authConfig).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 已存在则更新
|
||||
if err := tx.Model(&domain.AuthConfig{}).
|
||||
Where("kb_id = ?", authConfig.KbID).
|
||||
Where("source_type = ?", authConfig.SourceType).
|
||||
Updates(authConfig).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthById(ctx context.Context, kbID string, id uint) (*domain.Auth, error) {
|
||||
var auth domain.Auth
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&domain.Auth{}).
|
||||
Where("kb_id = ?", kbID).
|
||||
Where("id = ?", id).
|
||||
First(&auth).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuthConfig(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.AuthConfig, error) {
|
||||
var authConfig domain.AuthConfig
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&domain.AuthConfig{}).
|
||||
Where("kb_id = ?", kbID).
|
||||
Where("source_type = ?", string(sourceType)).
|
||||
Order("created_at DESC").
|
||||
Limit(1).
|
||||
First(&authConfig).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authConfig, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetAuths(ctx context.Context, kbID string, sourceType consts.SourceType) ([]domain.Auth, error) {
|
||||
auths := make([]domain.Auth, 0)
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&domain.Auth{}).
|
||||
Where("kb_id = ?", kbID).
|
||||
Where("source_type in (?)", append(consts.BotSourceTypes, sourceType)).
|
||||
Order("last_login_time DESC").
|
||||
Find(&auths).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return auths, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetOrCreateAuth(ctx context.Context, auth *domain.Auth, sourceType consts.SourceType) (*domain.Auth, error) {
|
||||
|
||||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var existing domain.Auth
|
||||
err := tx.Model(&domain.Auth{}).
|
||||
Where("kb_id = ?", auth.KBID).
|
||||
Where("source_type = ?", auth.SourceType).
|
||||
Where("union_id = ?", auth.UnionID).
|
||||
First(&existing).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
var count int64
|
||||
// 统计时排除机器人类型的认证,机器人不占用license限制名额
|
||||
if err := tx.Model(&domain.Auth{}).
|
||||
Where("kb_id = ?", auth.KBID).
|
||||
Where("source_type NOT IN (?)", consts.BotSourceTypes).
|
||||
Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if int(count) >= domain.GetBaseEditionLimitation(ctx).MaxSSOUser {
|
||||
return fmt.Errorf("exceed max auth limit for kb %s, current count: %d, max limit: %d", auth.KBID, count, domain.GetBaseEditionLimitation(ctx).MaxSSOUser)
|
||||
}
|
||||
|
||||
auth.LastLoginTime = time.Now()
|
||||
if err := tx.Model(&domain.Auth{}).Create(auth).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
updateMap := map[string]interface{}{
|
||||
"last_login_time": time.Now(),
|
||||
"user_info": auth.UserInfo,
|
||||
}
|
||||
if err := tx.Model(&domain.Auth{}).Where("id = ?", existing.ID).Updates(updateMap).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := r.db.Model(&domain.Auth{}).
|
||||
Where("kb_id = ?", auth.KBID).
|
||||
Where("source_type = ?", auth.SourceType).
|
||||
Where("union_id = ?", auth.UnionID).
|
||||
First(&auth).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
Reference in New Issue
Block a user