147 lines
4.2 KiB
Go
147 lines
4.2 KiB
Go
package pg
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/samber/lo"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"gorm.io/gorm"
|
|
|
|
v1 "github.com/chaitin/panda-wiki/api/user/v1"
|
|
"github.com/chaitin/panda-wiki/consts"
|
|
"github.com/chaitin/panda-wiki/domain"
|
|
"github.com/chaitin/panda-wiki/log"
|
|
"github.com/chaitin/panda-wiki/store/pg"
|
|
)
|
|
|
|
type UserRepository struct {
|
|
db *pg.DB
|
|
logger *log.Logger
|
|
}
|
|
|
|
func NewUserRepository(db *pg.DB, logger *log.Logger) *UserRepository {
|
|
return &UserRepository{
|
|
db: db,
|
|
logger: logger.WithModule("repo.pg.user"),
|
|
}
|
|
}
|
|
|
|
func (r *UserRepository) UpsertDefaultUser(ctx context.Context, user *domain.User) error {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
user.Password = string(hashedPassword)
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// First try to find existing user
|
|
var existingUser domain.User
|
|
err := tx.Where("account = ?", user.Account).First(&existingUser).Error
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
// User doesn't exist, create new user
|
|
if err := tx.Create(user).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
// User exists, update password
|
|
return tx.Model(&existingUser).Update("password", user.Password).Error
|
|
})
|
|
}
|
|
|
|
func (r *UserRepository) CreateUser(ctx context.Context, user *domain.User, edition consts.LicenseEdition) error {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
user.Password = string(hashedPassword)
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
var count int64
|
|
if err := tx.Model(&domain.User{}).Count(&count).Error; err != nil {
|
|
return err
|
|
}
|
|
if count >= domain.GetBaseEditionLimitation(ctx).MaxAdmin {
|
|
return fmt.Errorf("exceed max admin limit, current count: %d, max limit: %d", count, domain.GetBaseEditionLimitation(ctx).MaxAdmin)
|
|
}
|
|
|
|
if err := tx.Create(user).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *UserRepository) VerifyUser(ctx context.Context, account string, password string) (*domain.User, error) {
|
|
var user domain.User
|
|
err := r.db.WithContext(ctx).Where("account = ?", account).First(&user).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
|
return nil, errors.New("invalid password")
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *UserRepository) GetUser(ctx context.Context, userID string) (*domain.User, error) {
|
|
var user domain.User
|
|
err := r.db.WithContext(ctx).
|
|
Where("id = ?", userID).
|
|
First(&user).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *UserRepository) ListUsers(ctx context.Context) ([]v1.UserListItemResp, error) {
|
|
var users []v1.UserListItemResp
|
|
err := r.db.WithContext(ctx).
|
|
Model(&domain.User{}).
|
|
Order("created_at DESC").
|
|
Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (r *UserRepository) GetUsersAccountMap(ctx context.Context) (map[string]string, error) {
|
|
var users []v1.UserListItemResp
|
|
err := r.db.WithContext(ctx).
|
|
Model(&domain.User{}).
|
|
Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
m := lo.SliceToMap(users, func(user v1.UserListItemResp) (string, string) {
|
|
return user.ID, user.Account
|
|
})
|
|
|
|
return m, nil
|
|
}
|
|
|
|
func (r *UserRepository) UpdateUserPassword(ctx context.Context, userID string, newPassword string) error {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", userID).Update("password", string(hashedPassword)).Error
|
|
}
|
|
|
|
func (r *UserRepository) DeleteUser(ctx context.Context, userID string) error {
|
|
if err := r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", userID).Delete(&domain.User{}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := r.db.WithContext(ctx).Model(&domain.KBUsers{}).Where("user_id = ?", userID).Delete(&domain.KBUsers{}).Error; err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|