Files
YouduWiki/backend/repo/pg/auth.go
2026-05-21 19:52:45 +08:00

330 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package pg
import (
"context"
"errors"
"fmt"
"time"
"github.com/samber/lo"
"gorm.io/gorm"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/store/cache"
"github.com/chaitin/panda-wiki/store/pg"
)
// AuthRepo is the PostgreSQL-backed repository for auths, auth groups and
// auth configs. Its dependencies are injected through NewAuthRepo.
type AuthRepo struct {
	db     *pg.DB       // database handle used by the queries in this repo
	logger *log.Logger  // injected logger
	cache  *cache.Cache // injected cache client
}
// NewAuthRepo builds an AuthRepo from its database handle, logger and cache.
func NewAuthRepo(db *pg.DB, logger *log.Logger, cache *cache.Cache) *AuthRepo {
	repo := &AuthRepo{}
	repo.db = db
	repo.logger = logger
	repo.cache = cache
	return repo
}
// GetAuthUserinfoByIDs loads the stored user info for the given auth IDs and
// returns it keyed by auth ID. It selects `id` and `user_info` (aliased as
// `auth_user_info`) from the auths table and skips rows whose source_type is
// one of consts.BotSourceTypes. IDs with no matching non-bot row are absent
// from the map. An empty authIDs returns a nil map and a nil error without
// querying the database.
func (r *AuthRepo) GetAuthUserinfoByIDs(ctx context.Context, authIDs []uint) (map[uint]*domain.AuthInfo, error) {
	if len(authIDs) == 0 {
		return nil, nil
	}
	var authUserInfo = []domain.AuthInfo{}
	err := r.db.WithContext(ctx).Table("auths").
		Select("id,user_info as auth_user_info").
		Where("id IN (?)", authIDs).
		Where("source_type NOT IN (?)", consts.BotSourceTypes).
		Find(&authUserInfo).Error
	if err != nil {
		return nil, err
	}
	result := make(map[uint]*domain.AuthInfo, len(authUserInfo))
	for i := range authUserInfo {
		// Point at the slice element itself, not at a range variable: before
		// Go 1.22 the range variable is shared by all iterations, so `&a`
		// would make every map entry alias the last row.
		result[authUserInfo[i].ID] = &authUserInfo[i]
	}
	return result, nil
}
// GetAuthGroupByAuthId returns the auth groups whose auth_ids array contains
// authID, i.e. the direct memberships only (parent groups are not included).
// When nothing matches, it returns an empty, non-nil slice.
func (r *AuthRepo) GetAuthGroupByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) {
	groups := make([]domain.AuthGroup, 0)
	query := r.db.WithContext(ctx).
		Model(&domain.AuthGroup{}).
		Where("? = ANY(auth_ids)", authID)
	if err := query.Find(&groups).Error; err != nil {
		return nil, err
	}
	return groups, nil
}
// getAllAuthGroupsAsMap loads every row of the auth group table (no filtering)
// and indexes it by group ID so parent links can be resolved without further
// queries. Each map value points at its own copy of the group.
func (r *AuthRepo) getAllAuthGroupsAsMap(ctx context.Context) (map[uint]*domain.AuthGroup, error) {
	var groups []domain.AuthGroup
	if err := r.db.WithContext(ctx).Find(&groups).Error; err != nil {
		return nil, err
	}
	// g is a fresh parameter on every call, so &g never aliases another entry.
	byID := func(g domain.AuthGroup) (uint, *domain.AuthGroup) {
		return g.ID, &g
	}
	return lo.SliceToMap(groups, byID), nil
}
// getAuthGroupsWithParentsByAuthId is a helper method that retrieves user's auth groups and all parent groups
func (r *AuthRepo) getAuthGroupsWithParentsByAuthId(ctx context.Context, authID uint) (map[uint]domain.AuthGroup, error) {
// Get user's direct auth groups
var directGroups []domain.AuthGroup
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
Where("? = ANY(auth_ids)", authID).
Find(&directGroups).Error
if err != nil {
return nil, err
}
if len(directGroups) == 0 {
return make(map[uint]domain.AuthGroup), nil
}
groupMap, err := r.getAllAuthGroupsAsMap(ctx)
if err != nil {
return nil, err
}
resultGroups := make(map[uint]domain.AuthGroup)
visited := make(map[uint]bool)
var findParents func(uint)
findParents = func(groupID uint) {
if visited[groupID] {
return // Avoid circular reference
}
visited[groupID] = true
group, exists := groupMap[groupID]
if !exists {
return // Group not found, end search
}
resultGroups[group.ID] = *group
if group.ParentID != nil {
findParents(*group.ParentID)
}
}
// Process user's direct groups and their parent groups
for _, group := range directGroups {
resultGroups[group.ID] = group
if group.ParentID != nil {
findParents(*group.ParentID)
}
}
return resultGroups, nil
}
// GetAuthGroupWithParentsByAuthId returns the user's direct auth groups together
// with all of their ancestor groups as a slice. The order is unspecified
// because it comes from map iteration.
func (r *AuthRepo) GetAuthGroupWithParentsByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) {
	byID, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID)
	if err != nil {
		return nil, err
	}
	groups := make([]domain.AuthGroup, len(byID))
	i := 0
	for _, g := range byID {
		groups[i] = g
		i++
	}
	return groups, nil
}
// GetAuthGroupIdsByAuthId returns the IDs of the auth groups whose auth_ids
// array contains authID (direct memberships only). When nothing matches, it
// returns an empty, non-nil slice.
func (r *AuthRepo) GetAuthGroupIdsByAuthId(ctx context.Context, authID uint) ([]int, error) {
	ids := make([]int, 0)
	query := r.db.WithContext(ctx).
		Model(&domain.AuthGroup{}).
		Where("? = ANY(auth_ids)", authID)
	if err := query.Pluck("id", &ids).Error; err != nil {
		return nil, err
	}
	return ids, nil
}
// GetAuthGroupIdsWithParentsByAuthId returns the IDs of the user's direct auth
// groups plus all of their ancestor groups, for permission inheritance. The
// order is unspecified.
func (r *AuthRepo) GetAuthGroupIdsWithParentsByAuthId(ctx context.Context, authID uint) ([]int, error) {
	byID, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID)
	if err != nil {
		return nil, err
	}
	ids := make([]int, 0, len(byID))
	// The map is keyed by each group's own ID, so the keys are the group IDs.
	for id := range byID {
		ids = append(ids, int(id))
	}
	return ids, nil
}
// GetAuthBySourceType returns the first auth (GORM First orders by primary key)
// whose source_type equals sourceType. The lookup is not scoped to a knowledge
// base. Errors from GORM, including gorm.ErrRecordNotFound when no row
// matches, are returned unchanged.
func (r *AuthRepo) GetAuthBySourceType(ctx context.Context, sourceType consts.SourceType) (*domain.Auth, error) {
	// Scan into a value rather than through a nil *domain.Auth, matching
	// GetAuthById and avoiding reliance on GORM allocating through **T.
	var auth domain.Auth
	if err := r.db.WithContext(ctx).
		Model(&domain.Auth{}).
		Where("source_type = ?", string(sourceType)).
		First(&auth).Error; err != nil {
		return nil, err
	}
	return &auth, nil
}
// GetAuthByKBIDAndSourceType returns the first auth (GORM First orders by
// primary key) in knowledge base kbID whose source_type equals sourceType.
// Errors from GORM, including gorm.ErrRecordNotFound when no row matches, are
// returned unchanged.
func (r *AuthRepo) GetAuthByKBIDAndSourceType(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.Auth, error) {
	// Scan into a value rather than through a nil *domain.Auth, matching
	// GetAuthById and avoiding reliance on GORM allocating through **T.
	var auth domain.Auth
	if err := r.db.WithContext(ctx).
		Model(&domain.Auth{}).
		Where("kb_id = ?", kbID).
		Where("source_type = ?", string(sourceType)).
		First(&auth).Error; err != nil {
		return nil, err
	}
	return &auth, nil
}
// CreateAuth inserts auth as a new row and returns the insert error, if any.
func (r *AuthRepo) CreateAuth(ctx context.Context, auth *domain.Auth) error {
	result := r.db.WithContext(ctx).Model(&domain.Auth{}).Create(auth)
	return result.Error
}
// DeleteAuth deletes the auth with the given ID, but only if it belongs to
// knowledge base kbID. Matching no row is not reported as an error.
func (r *AuthRepo) DeleteAuth(ctx context.Context, kbID string, authId int64) error {
	scoped := r.db.WithContext(ctx).
		Where("kb_id = ?", kbID).
		Where("id = ?", authId)
	return scoped.Delete(&domain.Auth{}).Error
}
// CreateAuthConfig upserts the auth config for (authConfig.KbID,
// authConfig.SourceType) inside one transaction. If no row matches, authConfig
// is created. Otherwise the matching row(s) are updated from authConfig.
//
// NOTE(review): per GORM docs, Updates with a struct only writes non-zero
// fields, so resetting a field to its zero value (false, "", 0) is not
// persisted on the update path — confirm whether that is intended.
func (r *AuthRepo) CreateAuthConfig(ctx context.Context, authConfig *domain.AuthConfig) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// Builds a fresh query scoped to this config's knowledge base and source type.
		scoped := func() *gorm.DB {
			return tx.Model(&domain.AuthConfig{}).
				Where("kb_id = ?", authConfig.KbID).
				Where("source_type = ?", authConfig.SourceType)
		}

		var existing domain.AuthConfig
		lookupErr := scoped().First(&existing).Error
		switch {
		case lookupErr == nil:
			// Update the existing config in place.
			return scoped().Updates(authConfig).Error
		case errors.Is(lookupErr, gorm.ErrRecordNotFound):
			return tx.Model(&domain.AuthConfig{}).Create(authConfig).Error
		default:
			return lookupErr
		}
	})
}
func (r *AuthRepo) GetAuthById(ctx context.Context, kbID string, id uint) (*domain.Auth, error) {
var auth domain.Auth
if err := r.db.WithContext(ctx).
Model(&domain.Auth{}).
Where("kb_id = ?", kbID).
Where("id = ?", id).
First(&auth).Error; err != nil {
return nil, err
}
return &auth, nil
}
func (r *AuthRepo) GetAuthConfig(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.AuthConfig, error) {
var authConfig domain.AuthConfig
if err := r.db.WithContext(ctx).
Model(&domain.AuthConfig{}).
Where("kb_id = ?", kbID).
Where("source_type = ?", string(sourceType)).
Order("created_at DESC").
Limit(1).
First(&authConfig).Error; err != nil {
return nil, err
}
return &authConfig, nil
}
// GetAuths lists the auths of knowledge base kbID whose source_type is either
// sourceType or one of consts.BotSourceTypes, most recent login first. When
// nothing matches, it returns an empty, non-nil slice.
func (r *AuthRepo) GetAuths(ctx context.Context, kbID string, sourceType consts.SourceType) ([]domain.Auth, error) {
	bots := consts.BotSourceTypes
	// Cap the shared package-level slice at its length before appending, so
	// append always copies into a fresh backing array. A plain
	// append(consts.BotSourceTypes, ...) writes into the shared array whenever
	// it has spare capacity, which is a data race between concurrent callers.
	sourceTypes := append(bots[:len(bots):len(bots)], sourceType)

	auths := make([]domain.Auth, 0)
	if err := r.db.WithContext(ctx).
		Model(&domain.Auth{}).
		Where("kb_id = ?", kbID).
		Where("source_type in (?)", sourceTypes).
		Order("last_login_time DESC").
		Find(&auths).Error; err != nil {
		return nil, err
	}
	return auths, nil
}
// GetOrCreateAuth finds the auth identified by (auth.KBID, auth.SourceType,
// auth.UnionID), or creates it, inside one transaction:
//   - If the row exists, its last_login_time is set to now and its user_info is
//     replaced with auth.UserInfo.
//   - Otherwise a new row is created from auth, with LastLoginTime set to now.
//     Creation fails if the knowledge base already has MaxSSOUser or more
//     non-bot auths (bot source types are excluded from the count).
//
// After the transaction commits, the row is re-read into auth, and auth is
// returned.
//
// The sourceType parameter is currently unused; the lookup uses auth.SourceType.
//
// NOTE(review): the count check and the insert are not serialized, so
// concurrent first-time logins could exceed the limit — confirm whether that
// matters.
func (r *AuthRepo) GetOrCreateAuth(ctx context.Context, auth *domain.Auth, sourceType consts.SourceType) (*domain.Auth, error) {
	if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var existing domain.Auth
		err := tx.Model(&domain.Auth{}).
			Where("kb_id = ?", auth.KBID).
			Where("source_type = ?", auth.SourceType).
			Where("union_id = ?", auth.UnionID).
			First(&existing).Error
		if err == nil {
			// Known user: refresh the login time and profile.
			updateMap := map[string]interface{}{
				"last_login_time": time.Now(),
				"user_info":       auth.UserInfo,
			}
			return tx.Model(&domain.Auth{}).Where("id = ?", existing.ID).Updates(updateMap).Error
		}
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}

		// New user: enforce the per-knowledge-base limit before creating.
		// Bot-type auths are excluded from the count because bots do not
		// consume license quota.
		var count int64
		if err := tx.Model(&domain.Auth{}).
			Where("kb_id = ?", auth.KBID).
			Where("source_type NOT IN (?)", consts.BotSourceTypes).
			Count(&count).Error; err != nil {
			return err
		}
		maxUsers := domain.GetBaseEditionLimitation(ctx).MaxSSOUser
		if int(count) >= maxUsers {
			return fmt.Errorf("exceed max auth limit for kb %s, current count: %d, max limit: %d", auth.KBID, count, maxUsers)
		}
		auth.LastLoginTime = time.Now()
		return tx.Model(&domain.Auth{}).Create(auth).Error
	}); err != nil {
		return nil, err
	}

	// Re-read the stored row so the caller gets its ID and current columns.
	// The query carries ctx so it honours the caller's cancellation and
	// deadline, like the transaction above.
	if err := r.db.WithContext(ctx).Model(&domain.Auth{}).
		Where("kb_id = ?", auth.KBID).
		Where("source_type = ?", auth.SourceType).
		Where("union_id = ?", auth.UnionID).
		First(auth).Error; err != nil {
		return nil, err
	}
	return auth, nil
}