Files
YouduWiki/backend/repo/pg/user.go
2026-05-21 19:52:45 +08:00

147 lines
4.2 KiB
Go

package pg
import (
"context"
"errors"
"fmt"
"github.com/samber/lo"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
v1 "github.com/chaitin/panda-wiki/api/user/v1"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/store/pg"
)
// UserRepository groups the user-related database operations defined in this
// file (creation, verification, lookup, listing, password update, deletion).
type UserRepository struct {
// db is the database handle every method issues its queries through.
db *pg.DB
// logger is the module-scoped logger set up by NewUserRepository.
logger *log.Logger
}
// NewUserRepository builds a UserRepository around db. The supplied logger is
// scoped with WithModule("repo.pg.user") before it is stored.
func NewUserRepository(db *pg.DB, logger *log.Logger) *UserRepository {
	repo := new(UserRepository)
	repo.db = db
	repo.logger = logger.WithModule("repo.pg.user")
	return repo
}
// UpsertDefaultUser stores user under its account: when no row with that
// account exists the user is inserted, otherwise only the existing row's
// password column is overwritten.
//
// The plaintext in user.Password is bcrypt-hashed first and user.Password is
// replaced by the hash in place, so the caller's struct is modified. The
// lookup and the subsequent write run inside one transaction.
//
// It returns a wrapped error if hashing fails, or the database error from the
// lookup, insert, or update.
func (r *UserRepository) UpsertDefaultUser(ctx context.Context, user *domain.User) error {
	digest, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}
	user.Password = string(digest)

	upsert := func(tx *gorm.DB) error {
		var found domain.User
		lookupErr := tx.Where("account = ?", user.Account).First(&found).Error
		switch {
		case lookupErr == nil:
			// Account already present: refresh its password only.
			return tx.Model(&found).Update("password", user.Password).Error
		case errors.Is(lookupErr, gorm.ErrRecordNotFound):
			// No row for this account yet: insert the supplied user.
			return tx.Create(user).Error
		default:
			// Any other lookup failure aborts the transaction.
			return lookupErr
		}
	}
	return r.db.WithContext(ctx).Transaction(upsert)
}
// CreateUser inserts user after bcrypt-hashing its password; user.Password is
// overwritten with the hash in place.
//
// Inside a single transaction it counts the existing user rows and refuses to
// insert when that count has reached the MaxAdmin value reported by
// domain.GetBaseEditionLimitation(ctx). The edition argument is not consulted
// by this method.
//
// It returns a wrapped error if hashing fails, an "exceed max admin limit"
// error when the limit is hit, or the database error from the count or insert.
func (r *UserRepository) CreateUser(ctx context.Context, user *domain.User, edition consts.LicenseEdition) error {
	digest, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}
	user.Password = string(digest)

	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var total int64
		if countErr := tx.Model(&domain.User{}).Count(&total).Error; countErr != nil {
			return countErr
		}
		if maxAdmin := domain.GetBaseEditionLimitation(ctx).MaxAdmin; total >= maxAdmin {
			return fmt.Errorf("exceed max admin limit, current count: %d, max limit: %d", total, maxAdmin)
		}
		return tx.Create(user).Error
	})
}
// VerifyUser loads the user whose account column equals account and checks
// password against the stored bcrypt hash.
//
// It returns the user on success. A failed lookup returns the database error
// unchanged (including gorm's not-found error); a hash mismatch returns an
// "invalid password" error.
func (r *UserRepository) VerifyUser(ctx context.Context, account string, password string) (*domain.User, error) {
	found := new(domain.User)
	if err := r.db.WithContext(ctx).Where("account = ?", account).First(found).Error; err != nil {
		return nil, err
	}

	mismatch := bcrypt.CompareHashAndPassword([]byte(found.Password), []byte(password))
	if mismatch != nil {
		return nil, errors.New("invalid password")
	}
	return found, nil
}
// GetUser fetches the user whose id column equals userID. Any database error,
// including gorm's not-found error, is returned unchanged with a nil user.
func (r *UserRepository) GetUser(ctx context.Context, userID string) (*domain.User, error) {
	found := &domain.User{}
	query := r.db.WithContext(ctx).Where("id = ?", userID)
	if err := query.First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// ListUsers returns every row of the domain.User model, scanned into
// v1.UserListItemResp and ordered by created_at descending. On a database
// error it returns nil and that error.
func (r *UserRepository) ListUsers(ctx context.Context) ([]v1.UserListItemResp, error) {
	var items []v1.UserListItemResp
	query := r.db.WithContext(ctx).Model(&domain.User{}).Order("created_at DESC")
	if err := query.Find(&items).Error; err != nil {
		return nil, err
	}
	return items, nil
}
// GetUsersAccountMap loads all users and returns a map from user ID to
// account. On a database error it returns nil and that error.
func (r *UserRepository) GetUsersAccountMap(ctx context.Context) (map[string]string, error) {
	var rows []v1.UserListItemResp
	if err := r.db.WithContext(ctx).Model(&domain.User{}).Find(&rows).Error; err != nil {
		return nil, err
	}

	// Each row contributes one ID -> account entry.
	toPair := func(row v1.UserListItemResp) (string, string) {
		return row.ID, row.Account
	}
	return lo.SliceToMap(rows, toPair), nil
}
// UpdateUserPassword bcrypt-hashes newPassword and writes the hash to the
// password column of the user whose id equals userID.
//
// It returns a wrapped error if hashing fails, otherwise the error (if any)
// from the update statement. The number of affected rows is not inspected
// here.
func (r *UserRepository) UpdateUserPassword(ctx context.Context, userID string, newPassword string) error {
	digest, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash password: %w", err)
	}

	result := r.db.WithContext(ctx).
		Model(&domain.User{}).
		Where("id = ?", userID).
		Update("password", string(digest))
	return result.Error
}
// DeleteUser removes the user whose id equals userID together with every
// domain.KBUsers row whose user_id equals userID.
//
// Both deletes run inside one transaction so that a failure of the second
// statement rolls back the first; otherwise the user row could disappear
// while its KBUsers rows are left behind. The delete order (user first, then
// KBUsers) is unchanged.
//
// It returns the database error from whichever delete fails, or nil.
func (r *UserRepository) DeleteUser(ctx context.Context, userID string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&domain.User{}).Where("id = ?", userID).Delete(&domain.User{}).Error; err != nil {
			return err
		}
		return tx.Model(&domain.KBUsers{}).Where("user_id = ?", userID).Delete(&domain.KBUsers{}).Error
	})
}