package pg import ( "context" "errors" "fmt" "time" "github.com/samber/lo" "gorm.io/gorm" "github.com/chaitin/panda-wiki/consts" "github.com/chaitin/panda-wiki/domain" "github.com/chaitin/panda-wiki/log" "github.com/chaitin/panda-wiki/store/cache" "github.com/chaitin/panda-wiki/store/pg" ) type AuthRepo struct { db *pg.DB logger *log.Logger cache *cache.Cache } func NewAuthRepo(db *pg.DB, logger *log.Logger, cache *cache.Cache) *AuthRepo { return &AuthRepo{ db: db, logger: logger, cache: cache, } } func (r *AuthRepo) GetAuthUserinfoByIDs(ctx context.Context, authIDs []uint) (map[uint]*domain.AuthInfo, error) { if len(authIDs) == 0 { return nil, nil } var authUserInfo = []domain.AuthInfo{} err := r.db.WithContext(ctx).Table("auths"). Select("id,user_info as auth_user_info"). Where("id IN (?) ", authIDs). Where("source_type NOT IN (?)", consts.BotSourceTypes). Find(&authUserInfo).Error if err != nil { return nil, err } //set map result := make(map[uint]*domain.AuthInfo, 0) for _, a := range authUserInfo { result[a.ID] = &a } return result, nil } func (r *AuthRepo) GetAuthGroupByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) { authGroups := make([]domain.AuthGroup, 0) err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}). Where("? = ANY(auth_ids)", authID). Find(&authGroups).Error if err != nil { return nil, err } return authGroups, nil } // getAllAuthGroupsAsMap fetches all auth groups and returns them as a map for quick lookup func (r *AuthRepo) getAllAuthGroupsAsMap(ctx context.Context) (map[uint]*domain.AuthGroup, error) { var allGroups []domain.AuthGroup err := r.db.WithContext(ctx).Find(&allGroups).Error if err != nil { return nil, err } groupMap := lo.SliceToMap(allGroups, func(group domain.AuthGroup) (uint, *domain.AuthGroup) { return group.ID, &group }) return groupMap, nil } // getAuthGroupsWithParentsByAuthId is a helper method that retrieves user's auth groups and all parent groups func (r *AuthRepo) getAuthGroupsWithParentsByAuthId(ctx context.Context, authID uint) (map[uint]domain.AuthGroup, error) { // Get user's direct auth groups var directGroups []domain.AuthGroup err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}). Where("? = ANY(auth_ids)", authID). Find(&directGroups).Error if err != nil { return nil, err } if len(directGroups) == 0 { return make(map[uint]domain.AuthGroup), nil } groupMap, err := r.getAllAuthGroupsAsMap(ctx) if err != nil { return nil, err } resultGroups := make(map[uint]domain.AuthGroup) visited := make(map[uint]bool) var findParents func(uint) findParents = func(groupID uint) { if visited[groupID] { return // Avoid circular reference } visited[groupID] = true group, exists := groupMap[groupID] if !exists { return // Group not found, end search } resultGroups[group.ID] = *group if group.ParentID != nil { findParents(*group.ParentID) } } // Process user's direct groups and their parent groups for _, group := range directGroups { resultGroups[group.ID] = group if group.ParentID != nil { findParents(*group.ParentID) } } return resultGroups, nil } // GetAuthGroupWithParentsByAuthId retrieves user's auth groups and all parent groups as slice func (r *AuthRepo) GetAuthGroupWithParentsByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) { groupsMap, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID) if err != nil { return nil, err } result := make([]domain.AuthGroup, 0, len(groupsMap)) for _, group := range groupsMap { result = append(result, group) } return result, nil } func (r *AuthRepo) GetAuthGroupIdsByAuthId(ctx context.Context, authID uint) ([]int, error) { groupIds := make([]int, 0) err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}). Where("? = ANY(auth_ids)", authID). Pluck("id", &groupIds).Error if err != nil { return nil, err } return groupIds, nil } // GetAuthGroupIdsWithParentsByAuthId retrieves user's auth group IDs and all parent group IDs (for permission inheritance) func (r *AuthRepo) GetAuthGroupIdsWithParentsByAuthId(ctx context.Context, authID uint) ([]int, error) { groupsMap, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID) if err != nil { return nil, err } result := make([]int, 0, len(groupsMap)) for _, group := range groupsMap { result = append(result, int(group.ID)) } return result, nil } func (r *AuthRepo) GetAuthBySourceType(ctx context.Context, sourceType consts.SourceType) (*domain.Auth, error) { var auth *domain.Auth if err := r.db.WithContext(ctx).Model(&domain.Auth{}).Where("source_type = ?", string(sourceType)).First(&auth).Error; err != nil { return nil, err } return auth, nil } func (r *AuthRepo) GetAuthByKBIDAndSourceType(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.Auth, error) { var auth *domain.Auth if err := r.db.WithContext(ctx).Model(&domain.Auth{}).Where("kb_id = ? AND source_type = ?", kbID, string(sourceType)).First(&auth).Error; err != nil { return nil, err } return auth, nil } func (r *AuthRepo) CreateAuth(ctx context.Context, auth *domain.Auth) error { return r.db.WithContext(ctx).Model(&domain.Auth{}).Create(auth).Error } func (r *AuthRepo) DeleteAuth(ctx context.Context, kbID string, authId int64) error { return r.db.WithContext(ctx).Where("kb_id = ? and id = ?", kbID, authId).Delete(&domain.Auth{}).Error } func (r *AuthRepo) CreateAuthConfig(ctx context.Context, authConfig *domain.AuthConfig) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var existing domain.AuthConfig err := tx.Model(&domain.AuthConfig{}). Where("kb_id = ?", authConfig.KbID). Where("source_type = ?", authConfig.SourceType). First(&existing).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { if err := tx.Model(&domain.AuthConfig{}). Create(authConfig).Error; err != nil { return err } return nil } return err } // 已存在则更新 if err := tx.Model(&domain.AuthConfig{}). Where("kb_id = ?", authConfig.KbID). Where("source_type = ?", authConfig.SourceType). Updates(authConfig).Error; err != nil { return err } return nil }) } func (r *AuthRepo) GetAuthById(ctx context.Context, kbID string, id uint) (*domain.Auth, error) { var auth domain.Auth if err := r.db.WithContext(ctx). Model(&domain.Auth{}). Where("kb_id = ?", kbID). Where("id = ?", id). First(&auth).Error; err != nil { return nil, err } return &auth, nil } func (r *AuthRepo) GetAuthConfig(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.AuthConfig, error) { var authConfig domain.AuthConfig if err := r.db.WithContext(ctx). Model(&domain.AuthConfig{}). Where("kb_id = ?", kbID). Where("source_type = ?", string(sourceType)). Order("created_at DESC"). Limit(1). First(&authConfig).Error; err != nil { return nil, err } return &authConfig, nil } func (r *AuthRepo) GetAuths(ctx context.Context, kbID string, sourceType consts.SourceType) ([]domain.Auth, error) { auths := make([]domain.Auth, 0) if err := r.db.WithContext(ctx). Model(&domain.Auth{}). Where("kb_id = ?", kbID). Where("source_type in (?)", append(consts.BotSourceTypes, sourceType)). Order("last_login_time DESC"). Find(&auths).Error; err != nil { return nil, err } return auths, nil } func (r *AuthRepo) GetOrCreateAuth(ctx context.Context, auth *domain.Auth, sourceType consts.SourceType) (*domain.Auth, error) { if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var existing domain.Auth err := tx.Model(&domain.Auth{}). Where("kb_id = ?", auth.KBID). Where("source_type = ?", auth.SourceType). Where("union_id = ?", auth.UnionID). First(&existing).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { var count int64 // 统计时排除机器人类型的认证,机器人不占用license限制名额 if err := tx.Model(&domain.Auth{}). Where("kb_id = ?", auth.KBID). Where("source_type NOT IN (?)", consts.BotSourceTypes). Count(&count).Error; err != nil { return err } if int(count) >= domain.GetBaseEditionLimitation(ctx).MaxSSOUser { return fmt.Errorf("exceed max auth limit for kb %s, current count: %d, max limit: %d", auth.KBID, count, domain.GetBaseEditionLimitation(ctx).MaxSSOUser) } auth.LastLoginTime = time.Now() if err := tx.Model(&domain.Auth{}).Create(auth).Error; err != nil { return err } return nil } return err } updateMap := map[string]interface{}{ "last_login_time": time.Now(), "user_info": auth.UserInfo, } if err := tx.Model(&domain.Auth{}).Where("id = ?", existing.ID).Updates(updateMap).Error; err != nil { return err } return nil }); err != nil { return nil, err } err := r.db.Model(&domain.Auth{}). Where("kb_id = ?", auth.KBID). Where("source_type = ?", auth.SourceType). Where("union_id = ?", auth.UnionID). First(&auth).Error if err != nil { return nil, err } return auth, nil }