init push
This commit is contained in:
342
backend/usecase/file.go
Normal file
342
backend/usecase/file.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package usecase
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/chaitin/panda-wiki/config"
|
||||
"github.com/chaitin/panda-wiki/consts"
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/repo/pg"
|
||||
"github.com/chaitin/panda-wiki/store/s3"
|
||||
"github.com/chaitin/panda-wiki/utils"
|
||||
)
|
||||
|
||||
type FileUsecase struct {
|
||||
logger *log.Logger
|
||||
s3Client *s3.MinioClient
|
||||
config *config.Config
|
||||
systemSettingRepo *pg.SystemSettingRepo
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewFileUsecase(logger *log.Logger, s3Client *s3.MinioClient, config *config.Config, systemSettingRepo *pg.SystemSettingRepo) *FileUsecase {
|
||||
return &FileUsecase{
|
||||
s3Client: s3Client,
|
||||
logger: logger.WithModule("usecase.file"),
|
||||
config: config,
|
||||
systemSettingRepo: systemSettingRepo,
|
||||
httpClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Prevent redirects to bypass SSRF checks
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *FileUsecase) UploadFileGetUrl(ctx context.Context, kbID string, file *multipart.FileHeader) (string, error) {
|
||||
key, err := u.UploadFile(ctx, kbID, file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("http://panda-wiki-minio:9000/static-file/%s", key), nil
|
||||
}
|
||||
|
||||
func (u *FileUsecase) UploadFile(ctx context.Context, kbID string, file *multipart.FileHeader) (string, error) {
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
|
||||
// Check denied extensions
|
||||
if err := u.checkDeniedExtension(ctx, ext); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("%s/%s%s", kbID, uuid.New().String(), ext)
|
||||
|
||||
size := file.Size
|
||||
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = mime.TypeByExtension(ext)
|
||||
}
|
||||
|
||||
resp, err := u.s3Client.PutObject(
|
||||
ctx,
|
||||
domain.Bucket,
|
||||
filename,
|
||||
src,
|
||||
size,
|
||||
minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
UserMetadata: map[string]string{
|
||||
"originalname": file.Filename,
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload failed: %w", err)
|
||||
}
|
||||
|
||||
return resp.Key, nil
|
||||
}
|
||||
|
||||
func (u *FileUsecase) UploadFileFromBytes(ctx context.Context, kbID string, filename string, fileBytes []byte) (string, error) {
|
||||
// Create a reader from the byte slice
|
||||
reader := bytes.NewReader(fileBytes)
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
|
||||
// Check denied extensions
|
||||
if err := u.checkDeniedExtension(ctx, ext); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s3Filename := fmt.Sprintf("%s/%s%s", kbID, uuid.New().String(), ext)
|
||||
|
||||
size := int64(len(fileBytes))
|
||||
|
||||
contentType := mime.TypeByExtension(ext)
|
||||
if contentType == "" {
|
||||
// Fallback content type if extension not recognized
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
resp, err := u.s3Client.PutObject(
|
||||
ctx,
|
||||
domain.Bucket,
|
||||
s3Filename,
|
||||
reader,
|
||||
size,
|
||||
minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
UserMetadata: map[string]string{
|
||||
"originalname": filename,
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload failed: %w", err)
|
||||
}
|
||||
|
||||
return resp.Key, nil
|
||||
}
|
||||
|
||||
func (u *FileUsecase) UploadFileFromReader(
|
||||
ctx context.Context,
|
||||
kbID string,
|
||||
filename string,
|
||||
reader io.Reader,
|
||||
size int64, // 必须提供对象大小
|
||||
) (string, error) {
|
||||
// 生成唯一文件名
|
||||
ext := strings.ToLower(filepath.Ext(filename))
|
||||
|
||||
// Check denied extensions
|
||||
if err := u.checkDeniedExtension(ctx, ext); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s3Filename := fmt.Sprintf("%s/%s%s", kbID, uuid.New().String(), ext)
|
||||
|
||||
// 获取内容类型
|
||||
contentType := mime.TypeByExtension(ext)
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream" // 默认类型
|
||||
}
|
||||
|
||||
// 上传到 S3
|
||||
_, err := u.s3Client.PutObject(
|
||||
ctx,
|
||||
domain.Bucket,
|
||||
s3Filename,
|
||||
reader,
|
||||
size, // 必须提供对象大小
|
||||
minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
UserMetadata: map[string]string{
|
||||
"originalname": filename,
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("S3 upload failed: %w", err)
|
||||
}
|
||||
|
||||
return s3Filename, nil
|
||||
}
|
||||
|
||||
func (u *FileUsecase) AnyDocUploadFile(ctx context.Context, file *multipart.FileHeader, path string) (string, error) {
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
|
||||
// Check denied extensions
|
||||
if err := u.checkDeniedExtension(ctx, ext); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
size := file.Size
|
||||
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = mime.TypeByExtension(ext)
|
||||
}
|
||||
|
||||
resp, err := u.s3Client.PutObject(
|
||||
ctx,
|
||||
domain.Bucket,
|
||||
path,
|
||||
src,
|
||||
size,
|
||||
minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
UserMetadata: map[string]string{
|
||||
"originalname": file.Filename,
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload failed: %w", err)
|
||||
}
|
||||
|
||||
return resp.Key, nil
|
||||
}
|
||||
|
||||
func (u *FileUsecase) UploadFileByUrl(ctx context.Context, kbID string, fileURL string) (string, error) {
|
||||
// Validate URL to prevent SSRF attacks
|
||||
if err := utils.ValidateURLForSSRF(fileURL); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fileURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := u.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to download file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle redirects manually to re-validate each redirect target
|
||||
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
||||
return "", fmt.Errorf("redirects are not allowed for security reasons")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("failed to download file, status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
const maxRemoteFileSize = 50 * 1024 * 1024 // 50MB
|
||||
lr := io.LimitReader(resp.Body, maxRemoteFileSize+1)
|
||||
data, err := io.ReadAll(lr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
if len(data) > maxRemoteFileSize {
|
||||
return "", fmt.Errorf("failed to read response body: file size exceeds limit of %d bytes", maxRemoteFileSize)
|
||||
}
|
||||
|
||||
urlPath := fileURL
|
||||
if idx := strings.Index(urlPath, "?"); idx != -1 {
|
||||
urlPath = urlPath[:idx]
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(urlPath))
|
||||
|
||||
if err := u.checkDeniedExtension(ctx, ext); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s3Filename := fmt.Sprintf("%s/%s%s", kbID, uuid.New().String(), ext)
|
||||
|
||||
// Derive content type from the actual data instead of trusting the remote header
|
||||
contentType := http.DetectContentType(data)
|
||||
if contentType == "" || contentType == "application/octet-stream" {
|
||||
if extType := mime.TypeByExtension(ext); extType != "" {
|
||||
contentType = extType
|
||||
} else {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
putResp, err := u.s3Client.PutObject(
|
||||
ctx,
|
||||
domain.Bucket,
|
||||
s3Filename,
|
||||
bytes.NewReader(data),
|
||||
int64(len(data)),
|
||||
minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload failed: %w", err)
|
||||
}
|
||||
|
||||
return putResp.Key, nil
|
||||
}
|
||||
|
||||
// checkDeniedExtension checks if the file extension is in the denied list
|
||||
func (u *FileUsecase) checkDeniedExtension(ctx context.Context, ext string) error {
|
||||
// Remove leading dot from extension
|
||||
ext = strings.TrimPrefix(ext, ".")
|
||||
if ext == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get denied extensions from system settings
|
||||
setting, err := u.systemSettingRepo.GetSystemSetting(ctx, consts.SystemSettingUpload)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil
|
||||
}
|
||||
u.logger.Error("failed to get upload denied extensions setting", "error", err)
|
||||
return nil // Don't block upload if we can't read settings
|
||||
}
|
||||
|
||||
var deniedSetting domain.UploadDeniedExtensionsSetting
|
||||
if err := json.Unmarshal(setting.Value, &deniedSetting); err != nil {
|
||||
u.logger.Error("failed to unmarshal denied extensions setting", "error", err)
|
||||
return nil // Don't block upload if settings are malformed
|
||||
}
|
||||
|
||||
// Check if extension is denied
|
||||
for _, deniedExt := range deniedSetting.DeniedExtensions {
|
||||
if strings.EqualFold(ext, deniedExt) {
|
||||
return fmt.Errorf("file extension '.%s' is not allowed for upload", ext)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user