init push

This commit is contained in:
2026-05-21 19:52:45 +08:00
commit e3f75311ab
1280 changed files with 179173 additions and 0 deletions

329
backend/repo/pg/auth.go Normal file
View File

@@ -0,0 +1,329 @@
package pg
import (
"context"
"errors"
"fmt"
"time"
"github.com/samber/lo"
"gorm.io/gorm"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/store/cache"
"github.com/chaitin/panda-wiki/store/pg"
)
type AuthRepo struct {
db *pg.DB
logger *log.Logger
cache *cache.Cache
}
func NewAuthRepo(db *pg.DB, logger *log.Logger, cache *cache.Cache) *AuthRepo {
return &AuthRepo{
db: db,
logger: logger,
cache: cache,
}
}
func (r *AuthRepo) GetAuthUserinfoByIDs(ctx context.Context, authIDs []uint) (map[uint]*domain.AuthInfo, error) {
if len(authIDs) == 0 {
return nil, nil
}
var authUserInfo = []domain.AuthInfo{}
err := r.db.WithContext(ctx).Table("auths").
Select("id,user_info as auth_user_info").
Where("id IN (?) ", authIDs).
Where("source_type NOT IN (?)", consts.BotSourceTypes).
Find(&authUserInfo).Error
if err != nil {
return nil, err
}
//set map
result := make(map[uint]*domain.AuthInfo, 0)
for _, a := range authUserInfo {
result[a.ID] = &a
}
return result, nil
}
func (r *AuthRepo) GetAuthGroupByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) {
authGroups := make([]domain.AuthGroup, 0)
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
Where("? = ANY(auth_ids)", authID).
Find(&authGroups).Error
if err != nil {
return nil, err
}
return authGroups, nil
}
// getAllAuthGroupsAsMap fetches all auth groups and returns them as a map for quick lookup
func (r *AuthRepo) getAllAuthGroupsAsMap(ctx context.Context) (map[uint]*domain.AuthGroup, error) {
var allGroups []domain.AuthGroup
err := r.db.WithContext(ctx).Find(&allGroups).Error
if err != nil {
return nil, err
}
groupMap := lo.SliceToMap(allGroups, func(group domain.AuthGroup) (uint, *domain.AuthGroup) {
return group.ID, &group
})
return groupMap, nil
}
// getAuthGroupsWithParentsByAuthId is a helper method that retrieves user's auth groups and all parent groups
func (r *AuthRepo) getAuthGroupsWithParentsByAuthId(ctx context.Context, authID uint) (map[uint]domain.AuthGroup, error) {
// Get user's direct auth groups
var directGroups []domain.AuthGroup
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
Where("? = ANY(auth_ids)", authID).
Find(&directGroups).Error
if err != nil {
return nil, err
}
if len(directGroups) == 0 {
return make(map[uint]domain.AuthGroup), nil
}
groupMap, err := r.getAllAuthGroupsAsMap(ctx)
if err != nil {
return nil, err
}
resultGroups := make(map[uint]domain.AuthGroup)
visited := make(map[uint]bool)
var findParents func(uint)
findParents = func(groupID uint) {
if visited[groupID] {
return // Avoid circular reference
}
visited[groupID] = true
group, exists := groupMap[groupID]
if !exists {
return // Group not found, end search
}
resultGroups[group.ID] = *group
if group.ParentID != nil {
findParents(*group.ParentID)
}
}
// Process user's direct groups and their parent groups
for _, group := range directGroups {
resultGroups[group.ID] = group
if group.ParentID != nil {
findParents(*group.ParentID)
}
}
return resultGroups, nil
}
// GetAuthGroupWithParentsByAuthId retrieves user's auth groups and all parent groups as slice
func (r *AuthRepo) GetAuthGroupWithParentsByAuthId(ctx context.Context, authID uint) ([]domain.AuthGroup, error) {
groupsMap, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID)
if err != nil {
return nil, err
}
result := make([]domain.AuthGroup, 0, len(groupsMap))
for _, group := range groupsMap {
result = append(result, group)
}
return result, nil
}
func (r *AuthRepo) GetAuthGroupIdsByAuthId(ctx context.Context, authID uint) ([]int, error) {
groupIds := make([]int, 0)
err := r.db.WithContext(ctx).Model(&domain.AuthGroup{}).
Where("? = ANY(auth_ids)", authID).
Pluck("id", &groupIds).Error
if err != nil {
return nil, err
}
return groupIds, nil
}
// GetAuthGroupIdsWithParentsByAuthId retrieves user's auth group IDs and all parent group IDs (for permission inheritance)
func (r *AuthRepo) GetAuthGroupIdsWithParentsByAuthId(ctx context.Context, authID uint) ([]int, error) {
groupsMap, err := r.getAuthGroupsWithParentsByAuthId(ctx, authID)
if err != nil {
return nil, err
}
result := make([]int, 0, len(groupsMap))
for _, group := range groupsMap {
result = append(result, int(group.ID))
}
return result, nil
}
func (r *AuthRepo) GetAuthBySourceType(ctx context.Context, sourceType consts.SourceType) (*domain.Auth, error) {
var auth *domain.Auth
if err := r.db.WithContext(ctx).Model(&domain.Auth{}).Where("source_type = ?", string(sourceType)).First(&auth).Error; err != nil {
return nil, err
}
return auth, nil
}
func (r *AuthRepo) GetAuthByKBIDAndSourceType(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.Auth, error) {
var auth *domain.Auth
if err := r.db.WithContext(ctx).Model(&domain.Auth{}).Where("kb_id = ? AND source_type = ?", kbID, string(sourceType)).First(&auth).Error; err != nil {
return nil, err
}
return auth, nil
}
func (r *AuthRepo) CreateAuth(ctx context.Context, auth *domain.Auth) error {
return r.db.WithContext(ctx).Model(&domain.Auth{}).Create(auth).Error
}
func (r *AuthRepo) DeleteAuth(ctx context.Context, kbID string, authId int64) error {
return r.db.WithContext(ctx).Where("kb_id = ? and id = ?", kbID, authId).Delete(&domain.Auth{}).Error
}
func (r *AuthRepo) CreateAuthConfig(ctx context.Context, authConfig *domain.AuthConfig) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var existing domain.AuthConfig
err := tx.Model(&domain.AuthConfig{}).
Where("kb_id = ?", authConfig.KbID).
Where("source_type = ?", authConfig.SourceType).
First(&existing).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if err := tx.Model(&domain.AuthConfig{}).
Create(authConfig).Error; err != nil {
return err
}
return nil
}
return err
}
// 已存在则更新
if err := tx.Model(&domain.AuthConfig{}).
Where("kb_id = ?", authConfig.KbID).
Where("source_type = ?", authConfig.SourceType).
Updates(authConfig).Error; err != nil {
return err
}
return nil
})
}
func (r *AuthRepo) GetAuthById(ctx context.Context, kbID string, id uint) (*domain.Auth, error) {
var auth domain.Auth
if err := r.db.WithContext(ctx).
Model(&domain.Auth{}).
Where("kb_id = ?", kbID).
Where("id = ?", id).
First(&auth).Error; err != nil {
return nil, err
}
return &auth, nil
}
func (r *AuthRepo) GetAuthConfig(ctx context.Context, kbID string, sourceType consts.SourceType) (*domain.AuthConfig, error) {
var authConfig domain.AuthConfig
if err := r.db.WithContext(ctx).
Model(&domain.AuthConfig{}).
Where("kb_id = ?", kbID).
Where("source_type = ?", string(sourceType)).
Order("created_at DESC").
Limit(1).
First(&authConfig).Error; err != nil {
return nil, err
}
return &authConfig, nil
}
func (r *AuthRepo) GetAuths(ctx context.Context, kbID string, sourceType consts.SourceType) ([]domain.Auth, error) {
auths := make([]domain.Auth, 0)
if err := r.db.WithContext(ctx).
Model(&domain.Auth{}).
Where("kb_id = ?", kbID).
Where("source_type in (?)", append(consts.BotSourceTypes, sourceType)).
Order("last_login_time DESC").
Find(&auths).Error; err != nil {
return nil, err
}
return auths, nil
}
func (r *AuthRepo) GetOrCreateAuth(ctx context.Context, auth *domain.Auth, sourceType consts.SourceType) (*domain.Auth, error) {
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var existing domain.Auth
err := tx.Model(&domain.Auth{}).
Where("kb_id = ?", auth.KBID).
Where("source_type = ?", auth.SourceType).
Where("union_id = ?", auth.UnionID).
First(&existing).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
var count int64
// 统计时排除机器人类型的认证机器人不占用license限制名额
if err := tx.Model(&domain.Auth{}).
Where("kb_id = ?", auth.KBID).
Where("source_type NOT IN (?)", consts.BotSourceTypes).
Count(&count).Error; err != nil {
return err
}
if int(count) >= domain.GetBaseEditionLimitation(ctx).MaxSSOUser {
return fmt.Errorf("exceed max auth limit for kb %s, current count: %d, max limit: %d", auth.KBID, count, domain.GetBaseEditionLimitation(ctx).MaxSSOUser)
}
auth.LastLoginTime = time.Now()
if err := tx.Model(&domain.Auth{}).Create(auth).Error; err != nil {
return err
}
return nil
}
return err
}
updateMap := map[string]interface{}{
"last_login_time": time.Now(),
"user_info": auth.UserInfo,
}
if err := tx.Model(&domain.Auth{}).Where("id = ?", existing.ID).Updates(updateMap).Error; err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
err := r.db.Model(&domain.Auth{}).
Where("kb_id = ?", auth.KBID).
Where("source_type = ?", auth.SourceType).
Where("union_id = ?", auth.UnionID).
First(&auth).Error
if err != nil {
return nil, err
}
return auth, nil
}