Browse Source

Merge branch 'multi-dataset' of https://git.openi.org.cn/OpenI/aiforge into multi-dataset

pull/2384/head
zhoupzh 3 years ago
parent
commit
89a7cd5753
5 changed files with 99 additions and 20 deletions
  1. +8
    -1
      models/attachment.go
  2. +15
    -3
      models/dataset.go
  3. +5
    -14
      modules/modelarts/modelarts.go
  4. +2
    -0
      modules/setting/setting.go
  5. +69
    -2
      routers/repo/modelarts.go

+ 8
- 1
models/attachment.go View File

@@ -110,8 +110,15 @@ func (a *Attachment) IncreaseDownloadCount() error {
} }


func IncreaseAttachmentUseNumber(uuid string) error { func IncreaseAttachmentUseNumber(uuid string) error {

uuidArray := strings.Split(uuid, ";")
for i := range uuidArray {
uuidArray[i] = "'" + uuidArray[i] + "'"
}

uuidInCondition := "(" + strings.Join(uuidArray, ",") + ")"
// Update use number. // Update use number.
if _, err := x.Exec("UPDATE `attachment` SET use_number=use_number+1 WHERE uuid=?", uuid); err != nil {
if _, err := x.Exec("UPDATE `attachment` SET use_number=use_number+1 WHERE uuid in " + uuidInCondition); err != nil {
return fmt.Errorf("increase attachment use count: %v", err) return fmt.Errorf("increase attachment use count: %v", err)
} }




+ 15
- 3
models/dataset.go View File

@@ -445,10 +445,22 @@ func UpdateDataset(ctx DBContext, rel *Dataset) error {
func IncreaseDatasetUseCount(uuid string) { func IncreaseDatasetUseCount(uuid string) {


IncreaseAttachmentUseNumber(uuid) IncreaseAttachmentUseNumber(uuid)
attachments, _ := GetAttachmentsByUUIDs(strings.Split(uuid, ";"))


attachment, _ := GetAttachmentByUUID(uuid)
if attachment != nil {
x.Exec("UPDATE `dataset` SET use_count=use_count+1 WHERE id=?", attachment.DatasetID)
countMap := make(map[int64]int)

for _, attachment := range attachments {
value, ok := countMap[attachment.DatasetID]
if ok {
countMap[attachment.DatasetID] = value + 1
} else {
countMap[attachment.DatasetID] = 1
}

}

for key, value := range countMap {
x.Exec("UPDATE `dataset` SET use_count=use_count+? WHERE id=?", value, key)
} }


} }


+ 5
- 14
modules/modelarts/modelarts.go View File

@@ -1,13 +1,14 @@
package modelarts package modelarts


import ( import (
"code.gitea.io/gitea/modules/timeutil"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"path" "path"
"strconv" "strconv"


"code.gitea.io/gitea/modules/timeutil"

"code.gitea.io/gitea/models" "code.gitea.io/gitea/models"
"code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/log"
@@ -96,6 +97,7 @@ type GenerateTrainJobReq struct {
VersionCount int VersionCount int
EngineName string EngineName string
TotalVersionCount int TotalVersionCount int
DatasetName string
} }


type GenerateInferenceJobReq struct { type GenerateInferenceJobReq struct {
@@ -335,11 +337,6 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
return err return err
} }


attach, err := models.GetAttachmentByUUID(req.Uuid)
if err != nil {
log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
return err
}
jobId := strconv.FormatInt(jobResult.JobID, 10) jobId := strconv.FormatInt(jobResult.JobID, 10)
err = models.CreateCloudbrain(&models.Cloudbrain{ err = models.CreateCloudbrain(&models.Cloudbrain{
Status: TransTrainJobStatus(jobResult.Status), Status: TransTrainJobStatus(jobResult.Status),
@@ -353,7 +350,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
VersionID: jobResult.VersionID, VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName, VersionName: jobResult.VersionName,
Uuid: req.Uuid, Uuid: req.Uuid,
DatasetName: attach.Name,
DatasetName: req.DatasetName,
CommitID: req.CommitID, CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion, IsLatestVersion: req.IsLatestVersion,
ComputeResource: models.NPUResource, ComputeResource: models.NPUResource,
@@ -408,12 +405,6 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
return err return err
} }


attach, err := models.GetAttachmentByUUID(req.Uuid)
if err != nil {
log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
return err
}

var jobTypes []string var jobTypes []string
jobTypes = append(jobTypes, string(models.JobTypeTrain)) jobTypes = append(jobTypes, string(models.JobTypeTrain))
repo := ctx.Repo.Repository repo := ctx.Repo.Repository
@@ -441,7 +432,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
VersionID: jobResult.VersionID, VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName, VersionName: jobResult.VersionName,
Uuid: req.Uuid, Uuid: req.Uuid,
DatasetName: attach.Name,
DatasetName: req.DatasetName,
CommitID: req.CommitID, CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion, IsLatestVersion: req.IsLatestVersion,
PreVersionName: req.PreVersionName, PreVersionName: req.PreVersionName,


+ 2
- 0
modules/setting/setting.go View File

@@ -465,6 +465,7 @@ var (
MaxDuration int64 MaxDuration int64
TrainGpuTypes string TrainGpuTypes string
TrainResourceSpecs string TrainResourceSpecs string
MaxDatasetNum int


//benchmark config //benchmark config
IsBenchmarkEnabled bool IsBenchmarkEnabled bool
@@ -1294,6 +1295,7 @@ func NewContext() {
MaxDuration = sec.Key("MAX_DURATION").MustInt64(14400) MaxDuration = sec.Key("MAX_DURATION").MustInt64(14400)
TrainGpuTypes = sec.Key("TRAIN_GPU_TYPES").MustString("") TrainGpuTypes = sec.Key("TRAIN_GPU_TYPES").MustString("")
TrainResourceSpecs = sec.Key("TRAIN_RESOURCE_SPECS").MustString("") TrainResourceSpecs = sec.Key("TRAIN_RESOURCE_SPECS").MustString("")
MaxDatasetNum = sec.Key("MAX_DATASET_NUM").MustInt(5)


sec = Cfg.Section("benchmark") sec = Cfg.Section("benchmark")
IsBenchmarkEnabled = sec.Key("ENABLED").MustBool(false) IsBenchmarkEnabled = sec.Key("ENABLED").MustBool(false)


+ 69
- 2
routers/repo/modelarts.go View File

@@ -979,12 +979,26 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath
outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/" outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/" logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
// dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
branch_name := form.BranchName branch_name := form.BranchName
isLatestVersion := modelarts.IsLatestVersion isLatestVersion := modelarts.IsLatestVersion
FlavorName := form.FlavorName FlavorName := form.FlavorName
VersionCount := modelarts.VersionCount VersionCount := modelarts.VersionCount
EngineName := form.EngineName EngineName := form.EngineName
if IsDatasetUseCountExceed(uuid) {
log.Error("DatasetUseCount is Exceed:%v")
trainJobErrorNewDataPrepare(ctx, form)
ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobNew, &form)
return
}
datasetName, err := GetDatasetNameByUUID(uuid)
if err != nil {
log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"])
trainJobErrorNewDataPrepare(ctx, form)
ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobNew, &form)
return
}
dataPath := GetObsDataPathByUUID(uuid)


count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID)
if err != nil { if err != nil {
@@ -1161,6 +1175,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
EngineName: EngineName, EngineName: EngineName,
VersionCount: VersionCount, VersionCount: VersionCount,
TotalVersionCount: modelarts.TotalVersionCount, TotalVersionCount: modelarts.TotalVersionCount,
DatasetName: datasetName,
} }


//将params转换Parameters.Parameter,出错时返回给前端 //将params转换Parameters.Parameter,出错时返回给前端
@@ -1222,13 +1237,28 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/" codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/"
outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/" outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/" logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
// dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
branch_name := form.BranchName branch_name := form.BranchName
PreVersionName := form.VersionName PreVersionName := form.VersionName
FlavorName := form.FlavorName FlavorName := form.FlavorName
EngineName := form.EngineName EngineName := form.EngineName
isLatestVersion := modelarts.IsLatestVersion isLatestVersion := modelarts.IsLatestVersion


if IsDatasetUseCountExceed(uuid) {
log.Error("DatasetUseCount is Exceed:%v")
versionErrorDataPrepare(ctx, form)
ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobVersionNew, &form)
return
}
datasetName, err := GetDatasetNameByUUID(uuid)
if err != nil {
log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"])
versionErrorDataPrepare(ctx, form)
ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobVersionNew, &form)
return
}
dataPath := GetObsDataPathByUUID(uuid)

canNewJob, _ := canUserCreateTrainJobVersion(ctx, latestTask.UserID) canNewJob, _ := canUserCreateTrainJobVersion(ctx, latestTask.UserID)
if !canNewJob { if !canNewJob {
ctx.RenderWithErr("user cann't new trainjob", tplModelArtsTrainJobVersionNew, &form) ctx.RenderWithErr("user cann't new trainjob", tplModelArtsTrainJobVersionNew, &form)
@@ -1386,6 +1416,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
EngineName: EngineName, EngineName: EngineName,
PreVersionName: PreVersionName, PreVersionName: PreVersionName,
TotalVersionCount: latestTask.TotalVersionCount + 1, TotalVersionCount: latestTask.TotalVersionCount + 1,
DatasetName: datasetName,
} }


err = modelarts.GenerateTrainJobVersion(ctx, req, jobID) err = modelarts.GenerateTrainJobVersion(ctx, req, jobID)
@@ -2420,3 +2451,39 @@ func TrainJobDownloadLogFile(ctx *context.Context) {
ctx.Resp.Header().Set("Cache-Control", "max-age=0") ctx.Resp.Header().Set("Cache-Control", "max-age=0")
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
} }
func GetObsDataPathByUUID(uuid string) string {
var obsDataPath string
uuidList := strings.Split(uuid, ";")
for k, _ := range uuidList {
if k <= 0 {
obsDataPath = "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
}
if k > 0 {
obsDataPathNext := ";" + "s3://" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
obsDataPath = obsDataPath + obsDataPathNext
}

}
return obsDataPath
}
func GetDatasetNameByUUID(uuid string) (string, error) {
uuidList := strings.Split(uuid, ";")
var datasetName string
for _, uuidStr := range uuidList {
attach, err := models.GetAttachmentByUUID(uuidStr)
if err != nil {
log.Error("GetAttachmentByUUID failed:%v", err)
return "", err
}
datasetName = datasetName + attach.Name + ";"
}
return datasetName, nil
}
func IsDatasetUseCountExceed(uuid string) bool {
uuidList := strings.Split(uuid, ";")
if len(uuidList) > setting.MaxDatasetNum {
return true
} else {
return false
}
}

Loading…
Cancel
Save