Dewvine/main.go
2024-08-16 18:34:05 +08:00

176 lines
3.7 KiB
Go

package main
import (
"encoding/json"
"github.com/gin-gonic/gin"
"math/rand"
"merak.axiomatrix.org/Axiomatrix_Org/optimized_go_tools/am_cors"
"merak.axiomatrix.org/Axiomatrix_Org/optimized_go_tools/am_ratelimit"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
// Data is one corpus entry as stored in the JSON data files and as returned
// to API clients. Field order and tags define the JSON wire format.
type Data struct {
	// UUID identifies the entry.
	UUID string `json:"uuid"`
	// Sentence is the text of the entry.
	Sentence string `json:"sentence"`
	// Type is the corpus category of the entry.
	// NOTE(review): presumably one of the names in dataTypes — confirm against the data files.
	Type string `json:"type"`
	// From is the attribution stored with the entry.
	From string `json:"from"`
	// CreatedAt is stored as an unsigned integer.
	// NOTE(review): presumably a Unix timestamp — confirm the unit.
	CreatedAt uint `json:"created_at"`
	// Length is the value compared against the request's maximum length.
	// NOTE(review): unit (bytes vs. characters) is set by the data files — confirm.
	Length int `json:"length"`
}
// Package-level configuration used by the request handler.
var (
	// dataDir is the root of the corpus files; the handler joins it with a
	// language sub-directory ("zh-CN" or "zh-TW").
	dataDir = "./data"

	// dataTypes lists the corpus type names accepted in the "type" query
	// parameter.
	dataTypes = []string{"internet", "literature", "philosophy", "poem"}

	// defaultLimitConfig supplies the rate-limit middleware installed on the
	// route.
	// NOTE(review): the meaning of the arguments (5, 1) is defined by
	// am_ratelimit.NewRateLimitConfig — confirm against that package.
	defaultLimitConfig = am_ratelimit.NewRateLimitConfig(5, 1)
)
// main starts the HTTP server on :7080 in gin release mode, with the CORS
// middleware installed globally and the rate-limit middleware on the single
// GET / route. It panics if the server fails to run.
func main() {
	gin.SetMode(gin.ReleaseMode)
	r := gin.Default()
	r.Use(am_cors.Cors())
	r.GET("/", defaultLimitConfig.RateLimitMiddleware, handleRandomSentences)
	if err := r.Run(":7080"); err != nil {
		panic(err)
	}
}

// handleRandomSentences serves GET / and responds with a JSON array of
// randomly chosen corpus entries.
//
// Query parameters:
//   - lang:   "zh-CN" or "zh-TW"; empty falls back to "zh-TW"; anything else
//     is answered with 400.
//   - length: maximum Data.Length of a returned entry; empty means 100000000.
//   - number: how many entries to return; empty means 1.
//   - type:   comma-separated corpus types, each of which must be listed in
//     dataTypes; empty selects every file in the language directory.
//
// Entries are drawn independently, so the same entry may appear more than
// once. A number <= 0 yields an empty result (encoded as JSON null).
//
// Responses: 400 for invalid parameters, 500 when the corpus cannot be read,
// 404 when entries were requested but none satisfies the filters, 200
// otherwise.
func handleRandomSentences(c *gin.Context) {
	// Resolve the language directory (Simplified / Traditional Chinese).
	var langDir string
	switch c.Query("lang") {
	case "zh-CN":
		langDir = filepath.Join(dataDir, "zh-CN")
	case "", "zh-TW":
		langDir = filepath.Join(dataDir, "zh-TW")
	default:
		c.JSON(400, gin.H{"error": "invalid lang"})
		return
	}

	// Maximum length of a requested entry. A malformed number is the
	// client's mistake, hence 400 rather than 500.
	maxLen, err := queryInt(c.Query("length"), 100000000)
	if err != nil {
		c.JSON(400, gin.H{"error": err.Error()})
		return
	}

	// Number of requested entries.
	// NOTE(review): there is no upper bound on this value; consider capping it.
	count, err := queryInt(c.Query("number"), 1)
	if err != nil {
		c.JSON(400, gin.H{"error": err.Error()})
		return
	}

	// Requested corpus types; nil means "pick from every corpus file".
	var types []string
	if t := c.Query("type"); t != "" {
		types = strings.Split(t, ",")
		for _, name := range types {
			if !Contains(dataTypes, name) {
				c.JSON(400, gin.H{"error": "invalid type"})
				return
			}
		}
	}

	contents, err := loadSentences(langDir, types)
	if err != nil {
		c.JSON(500, gin.H{"error": err.Error()})
		return
	}

	// Filter by length up front instead of retrying random picks until one
	// fits: retrying never terminates when nothing fits, and picking from an
	// empty corpus panics in Intn.
	candidates := filterByLength(contents, maxLen)
	if count > 0 && len(candidates) == 0 {
		c.JSON(404, gin.H{"error": "no sentence matches the request"})
		return
	}

	// Seed one generator per request; re-seeding from the clock for every
	// pick can repeat the same value when the clock has not advanced.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	var results []Data
	for i := 0; i < count; i++ {
		results = append(results, candidates[rng.Intn(len(candidates))])
	}
	c.JSON(200, results)
}

// queryInt parses raw as a decimal integer, returning fallback when raw is empty.
func queryInt(raw string, fallback int) (int, error) {
	if raw == "" {
		return fallback, nil
	}
	return strconv.Atoi(raw)
}

// loadSentences reads the corpus files in dir and concatenates their entries.
// When types is non-empty only files whose name without the ".json" suffix is
// listed in types are read. Sub-directories are skipped.
func loadSentences(dir string, types []string) ([]Data, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	var all []Data
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if len(types) > 0 && !Contains(types, strings.TrimSuffix(entry.Name(), ".json")) {
			continue
		}
		datas, err := GetFile(filepath.Join(dir, entry.Name()))
		if err != nil {
			return nil, err
		}
		all = append(all, datas...)
	}
	return all, nil
}

// filterByLength returns the entries of contents whose Length is at most maxLen.
func filterByLength(contents []Data, maxLen int) []Data {
	kept := make([]Data, 0, len(contents))
	for _, entry := range contents {
		if entry.Length <= maxLen {
			kept = append(kept, entry)
		}
	}
	return kept
}
// GetFile reads the file at filePath and decodes its contents as a JSON array
// of Data entries.
//
// It returns the read error if the file cannot be read and the decode error
// if the contents are not valid JSON for []Data; in both cases the returned
// slice is nil.
func GetFile(filePath string) ([]Data, error) {
	raw, readErr := os.ReadFile(filePath)
	if readErr != nil {
		return nil, readErr
	}

	var parsed []Data
	if decodeErr := json.Unmarshal(raw, &parsed); decodeErr != nil {
		return nil, decodeErr
	}
	return parsed, nil
}
// Contains reports whether target is equal to at least one element of arr.
// A nil or empty arr never contains anything.
func Contains(arr []string, target string) bool {
	for i := 0; i < len(arr); i++ {
		if arr[i] == target {
			return true
		}
	}
	return false
}