Browse Source

Merge branch 'multi-dataset' of https://git.openi.org.cn/OpenI/aiforge into multi-dataset

pull/2384/head
zhoupzh 3 years ago
parent
commit
89a7cd5753
5 changed files with 99 additions and 20 deletions
  1. +8
    -1
      models/attachment.go
  2. +15
    -3
      models/dataset.go
  3. +5
    -14
      modules/modelarts/modelarts.go
  4. +2
    -0
      modules/setting/setting.go
  5. +69
    -2
      routers/repo/modelarts.go

+ 8
- 1
models/attachment.go View File

@@ -110,8 +110,15 @@ func (a *Attachment) IncreaseDownloadCount() error {
}

func IncreaseAttachmentUseNumber(uuid string) error {

uuidArray := strings.Split(uuid, ";")
for i := range uuidArray {
uuidArray[i] = "'" + uuidArray[i] + "'"
}

uuidInCondition := "(" + strings.Join(uuidArray, ",") + ")"
// Update use number.
if _, err := x.Exec("UPDATE `attachment` SET use_number=use_number+1 WHERE uuid=?", uuid); err != nil {
if _, err := x.Exec("UPDATE `attachment` SET use_number=use_number+1 WHERE uuid in " + uuidInCondition); err != nil {
return fmt.Errorf("increase attachment use count: %v", err)
}



+ 15
- 3
models/dataset.go View File

@@ -445,10 +445,22 @@ func UpdateDataset(ctx DBContext, rel *Dataset) error {
func IncreaseDatasetUseCount(uuid string) {

IncreaseAttachmentUseNumber(uuid)
attachments, _ := GetAttachmentsByUUIDs(strings.Split(uuid, ";"))

attachment, _ := GetAttachmentByUUID(uuid)
if attachment != nil {
x.Exec("UPDATE `dataset` SET use_count=use_count+1 WHERE id=?", attachment.DatasetID)
countMap := make(map[int64]int)

for _, attachment := range attachments {
value, ok := countMap[attachment.DatasetID]
if ok {
countMap[attachment.DatasetID] = value + 1
} else {
countMap[attachment.DatasetID] = 1
}

}

for key, value := range countMap {
x.Exec("UPDATE `dataset` SET use_count=use_count+? WHERE id=?", value, key)
}

}


+ 5
- 14
modules/modelarts/modelarts.go View File

@@ -1,13 +1,14 @@
package modelarts

import (
"code.gitea.io/gitea/modules/timeutil"
"encoding/json"
"errors"
"fmt"
"path"
"strconv"

"code.gitea.io/gitea/modules/timeutil"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
@@ -96,6 +97,7 @@ type GenerateTrainJobReq struct {
VersionCount int
EngineName string
TotalVersionCount int
DatasetName string
}

type GenerateInferenceJobReq struct {
@@ -335,11 +337,6 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
return err
}

attach, err := models.GetAttachmentByUUID(req.Uuid)
if err != nil {
log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
return err
}
jobId := strconv.FormatInt(jobResult.JobID, 10)
err = models.CreateCloudbrain(&models.Cloudbrain{
Status: TransTrainJobStatus(jobResult.Status),
@@ -353,7 +350,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName,
Uuid: req.Uuid,
DatasetName: attach.Name,
DatasetName: req.DatasetName,
CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion,
ComputeResource: models.NPUResource,
@@ -408,12 +405,6 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
return err
}

attach, err := models.GetAttachmentByUUID(req.Uuid)
if err != nil {
log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
return err
}

var jobTypes []string
jobTypes = append(jobTypes, string(models.JobTypeTrain))
repo := ctx.Repo.Repository
@@ -441,7 +432,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName,
Uuid: req.Uuid,
DatasetName: attach.Name,
DatasetName: req.DatasetName,
CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion,
PreVersionName: req.PreVersionName,


+ 2
- 0
modules/setting/setting.go View File

@@ -465,6 +465,7 @@ var (
MaxDuration int64
TrainGpuTypes string
TrainResourceSpecs string
MaxDatasetNum int

//benchmark config
IsBenchmarkEnabled bool
@@ -1294,6 +1295,7 @@ func NewContext() {
MaxDuration = sec.Key("MAX_DURATION").MustInt64(14400)
TrainGpuTypes = sec.Key("TRAIN_GPU_TYPES").MustString("")
TrainResourceSpecs = sec.Key("TRAIN_RESOURCE_SPECS").MustString("")
MaxDatasetNum = sec.Key("MAX_DATASET_NUM").MustInt(5)

sec = Cfg.Section("benchmark")
IsBenchmarkEnabled = sec.Key("ENABLED").MustBool(false)


+ 69
- 2
routers/repo/modelarts.go View File

@@ -979,12 +979,26 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath
outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
// dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
branch_name := form.BranchName
isLatestVersion := modelarts.IsLatestVersion
FlavorName := form.FlavorName
VersionCount := modelarts.VersionCount
EngineName := form.EngineName
if IsDatasetUseCountExceed(uuid) {
log.Error("DatasetUseCount is Exceed:%v")
trainJobErrorNewDataPrepare(ctx, form)
ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobNew, &form)
return
}
datasetName, err := GetDatasetNameByUUID(uuid)
if err != nil {
log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"])
trainJobErrorNewDataPrepare(ctx, form)
ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobNew, &form)
return
}
dataPath := GetObsDataPathByUUID(uuid)

count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID)
if err != nil {
@@ -1161,6 +1175,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
EngineName: EngineName,
VersionCount: VersionCount,
TotalVersionCount: modelarts.TotalVersionCount,
DatasetName: datasetName,
}

//将params转换Parameters.Parameter,出错时返回给前端
@@ -1222,13 +1237,28 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/"
outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
// dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
branch_name := form.BranchName
PreVersionName := form.VersionName
FlavorName := form.FlavorName
EngineName := form.EngineName
isLatestVersion := modelarts.IsLatestVersion

if IsDatasetUseCountExceed(uuid) {
log.Error("DatasetUseCount is Exceed:%v")
versionErrorDataPrepare(ctx, form)
ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobVersionNew, &form)
return
}
datasetName, err := GetDatasetNameByUUID(uuid)
if err != nil {
log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"])
versionErrorDataPrepare(ctx, form)
ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobVersionNew, &form)
return
}
dataPath := GetObsDataPathByUUID(uuid)

canNewJob, _ := canUserCreateTrainJobVersion(ctx, latestTask.UserID)
if !canNewJob {
ctx.RenderWithErr("user cann't new trainjob", tplModelArtsTrainJobVersionNew, &form)
@@ -1386,6 +1416,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
EngineName: EngineName,
PreVersionName: PreVersionName,
TotalVersionCount: latestTask.TotalVersionCount + 1,
DatasetName: datasetName,
}

err = modelarts.GenerateTrainJobVersion(ctx, req, jobID)
@@ -2420,3 +2451,39 @@ func TrainJobDownloadLogFile(ctx *context.Context) {
ctx.Resp.Header().Set("Cache-Control", "max-age=0")
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}
// GetObsDataPathByUUID builds the OBS data path string for one or more
// attachment UUIDs passed as a single ";"-separated string.
//
// Every UUID is mapped to
//
//	<Bucket>/<BasePath><uuid[0]>/<uuid[1]>/<uuid><uuid>/
//
// The first entry is emitted as a bare "/"-rooted path and every following
// entry is appended as ";s3://<path>". For a single UUID the result is
// identical to what the previous single-dataset code produced.
//
// Entries shorter than two characters (for example the empty strings produced
// by an empty argument or a stray ";") cannot form the two-level directory
// prefix and are skipped instead of panicking on the slice expression.
func GetObsDataPathByUUID(uuid string) string {
	var obsDataPath strings.Builder
	for _, id := range strings.Split(uuid, ";") {
		if len(id) < 2 {
			continue
		}
		// Build the path from the current element, not from the whole
		// ";"-joined argument.
		objectDir := setting.Bucket + "/" + setting.BasePath + path.Join(id[0:1], id[1:2]) + "/" + id + id + "/"
		if obsDataPath.Len() == 0 {
			obsDataPath.WriteString("/" + objectDir)
		} else {
			obsDataPath.WriteString(";s3://" + objectDir)
		}
	}
	return obsDataPath.String()
}
// GetDatasetNameByUUID resolves a ";"-separated list of attachment UUIDs to
// the matching attachment names, joined with ";" in the same order.
//
// The names are joined without a trailing separator, so a single UUID yields
// exactly the attachment's name.
//
// Any error from models.GetAttachmentByUUID is logged and returned as-is
// together with an empty name string; no partial result is returned.
func GetDatasetNameByUUID(uuid string) (string, error) {
	uuidList := strings.Split(uuid, ";")
	names := make([]string, 0, len(uuidList))
	for _, uuidStr := range uuidList {
		attach, err := models.GetAttachmentByUUID(uuidStr)
		if err != nil {
			log.Error("GetAttachmentByUUID failed:%v", err)
			return "", err
		}
		names = append(names, attach.Name)
	}
	return strings.Join(names, ";"), nil
}
// IsDatasetUseCountExceed reports whether the ";"-separated uuid list holds
// more entries than setting.MaxDatasetNum allows.
func IsDatasetUseCountExceed(uuid string) bool {
	// Splitting on a non-empty separator always yields one more element than
	// there are separators, so counting separators gives the same length.
	datasetNum := strings.Count(uuid, ";") + 1
	return datasetNum > setting.MaxDatasetNum
}

Loading…
Cancel
Save