Browse Source

Merge pull request 'fix-test2256' (#2345) from fix-test2256 into multi-dataset

Reviewed-on: https://git.openi.org.cn/OpenI/aiforge/pulls/2345
pull/2384/head
liuzx 3 years ago
parent
commit
f17e567a8f
3 changed files with 76 additions and 16 deletions
  1. +5
    -14
      modules/modelarts/modelarts.go
  2. +2
    -0
      modules/setting/setting.go
  3. +69
    -2
      routers/repo/modelarts.go

+ 5
- 14
modules/modelarts/modelarts.go View File

@@ -1,13 +1,14 @@
package modelarts

import (
"code.gitea.io/gitea/modules/timeutil"
"encoding/json"
"errors"
"fmt"
"path"
"strconv"

"code.gitea.io/gitea/modules/timeutil"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
@@ -96,6 +97,7 @@ type GenerateTrainJobReq struct {
VersionCount int
EngineName string
TotalVersionCount int
DatasetName string
}

type GenerateInferenceJobReq struct {
@@ -335,11 +337,6 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
return err
}

attach, err := models.GetAttachmentByUUID(req.Uuid)
if err != nil {
log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
return err
}
jobId := strconv.FormatInt(jobResult.JobID, 10)
err = models.CreateCloudbrain(&models.Cloudbrain{
Status: TransTrainJobStatus(jobResult.Status),
@@ -353,7 +350,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error
VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName,
Uuid: req.Uuid,
DatasetName: attach.Name,
DatasetName: req.DatasetName,
CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion,
ComputeResource: models.NPUResource,
@@ -408,12 +405,6 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
return err
}

attach, err := models.GetAttachmentByUUID(req.Uuid)
if err != nil {
log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
return err
}

var jobTypes []string
jobTypes = append(jobTypes, string(models.JobTypeTrain))
repo := ctx.Repo.Repository
@@ -441,7 +432,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
VersionID: jobResult.VersionID,
VersionName: jobResult.VersionName,
Uuid: req.Uuid,
DatasetName: attach.Name,
DatasetName: req.DatasetName,
CommitID: req.CommitID,
IsLatestVersion: req.IsLatestVersion,
PreVersionName: req.PreVersionName,


+ 2
- 0
modules/setting/setting.go View File

@@ -465,6 +465,7 @@ var (
MaxDuration int64
TrainGpuTypes string
TrainResourceSpecs string
MaxDatasetNum int

//benchmark config
IsBenchmarkEnabled bool
@@ -1294,6 +1295,7 @@ func NewContext() {
MaxDuration = sec.Key("MAX_DURATION").MustInt64(14400)
TrainGpuTypes = sec.Key("TRAIN_GPU_TYPES").MustString("")
TrainResourceSpecs = sec.Key("TRAIN_RESOURCE_SPECS").MustString("")
MaxDatasetNum = sec.Key("MAX_DATASET_NUM").MustInt(5)

sec = Cfg.Section("benchmark")
IsBenchmarkEnabled = sec.Key("ENABLED").MustBool(false)


+ 69
- 2
routers/repo/modelarts.go View File

@@ -979,12 +979,26 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath
outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
// dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
branch_name := form.BranchName
isLatestVersion := modelarts.IsLatestVersion
FlavorName := form.FlavorName
VersionCount := modelarts.VersionCount
EngineName := form.EngineName
if IsDatasetUseCountExceed(uuid) {
log.Error("DatasetUseCount is Exceed:%v")
trainJobErrorNewDataPrepare(ctx, form)
ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobNew, &form)
return
}
datasetName, err := GetDatasetNameByUUID(uuid)
if err != nil {
log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"])
trainJobErrorNewDataPrepare(ctx, form)
ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobNew, &form)
return
}
dataPath := GetObsDataPathByUUID(uuid)

count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID)
if err != nil {
@@ -1161,6 +1175,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
EngineName: EngineName,
VersionCount: VersionCount,
TotalVersionCount: modelarts.TotalVersionCount,
DatasetName: datasetName,
}

//将params转换Parameters.Parameter,出错时返回给前端
@@ -1222,13 +1237,28 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/"
outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
// dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/"
branch_name := form.BranchName
PreVersionName := form.VersionName
FlavorName := form.FlavorName
EngineName := form.EngineName
isLatestVersion := modelarts.IsLatestVersion

if IsDatasetUseCountExceed(uuid) {
log.Error("DatasetUseCount is Exceed:%v")
versionErrorDataPrepare(ctx, form)
ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobVersionNew, &form)
return
}
datasetName, err := GetDatasetNameByUUID(uuid)
if err != nil {
log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"])
versionErrorDataPrepare(ctx, form)
ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobVersionNew, &form)
return
}
dataPath := GetObsDataPathByUUID(uuid)

canNewJob, _ := canUserCreateTrainJobVersion(ctx, latestTask.UserID)
if !canNewJob {
ctx.RenderWithErr("user cann't new trainjob", tplModelArtsTrainJobVersionNew, &form)
@@ -1386,6 +1416,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
EngineName: EngineName,
PreVersionName: PreVersionName,
TotalVersionCount: latestTask.TotalVersionCount + 1,
DatasetName: datasetName,
}

err = modelarts.GenerateTrainJobVersion(ctx, req, jobID)
@@ -2420,3 +2451,39 @@ func TrainJobDownloadLogFile(ctx *context.Context) {
ctx.Resp.Header().Set("Cache-Control", "max-age=0")
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}
// GetObsDataPathByUUID builds the OBS data path string for one or more
// dataset attachments.
//
// uuid is a ";"-separated list of attachment UUIDs. The first usable entry is
// rendered as a bucket-rooted path
// ("/<bucket>/<base path><u[0]>/<u[1]>/<uuid><uuid>/") and every following
// entry is appended as ";s3://<bucket>/<base path>...". Each segment is
// derived from its own entry of the list rather than from the whole input
// string, so multi-dataset inputs produce one valid path per dataset.
//
// Entries shorter than two characters cannot be mapped to the two-level
// shard directory and are skipped; an input with no usable entry yields "".
func GetObsDataPathByUUID(uuid string) string {
	var obsDataPath string
	for _, id := range strings.Split(uuid, ";") {
		if len(id) < 2 {
			// Slicing id[0:1]/id[1:2] below would panic on such an entry.
			continue
		}
		// Path shared by both forms, without the leading "/" or "s3://".
		objectDir := setting.Bucket + "/" + setting.BasePath + path.Join(id[0:1], id[1:2]) + "/" + id + id + "/"
		if obsDataPath == "" {
			obsDataPath = "/" + objectDir
		} else {
			obsDataPath += ";" + "s3://" + objectDir
		}
	}
	return obsDataPath
}
// GetDatasetNameByUUID resolves the attachment name for every UUID in uuid
// and returns the names joined with ";".
//
// uuid is a ";"-separated list of attachment UUIDs. The names are returned in
// the same order as the UUIDs, separated by ";" with no trailing separator,
// so a single UUID yields exactly the attachment's name.
//
// If any lookup through models.GetAttachmentByUUID fails, the error is logged
// and returned together with an empty string.
func GetDatasetNameByUUID(uuid string) (string, error) {
	uuidList := strings.Split(uuid, ";")
	names := make([]string, 0, len(uuidList))
	for _, uuidStr := range uuidList {
		attach, err := models.GetAttachmentByUUID(uuidStr)
		if err != nil {
			log.Error("GetAttachmentByUUID failed:%v", err)
			return "", err
		}
		names = append(names, attach.Name)
	}
	return strings.Join(names, ";"), nil
}
// IsDatasetUseCountExceed reports whether uuid, read as a ";"-separated
// list, holds more entries than setting.MaxDatasetNum.
func IsDatasetUseCountExceed(uuid string) bool {
	// A list with k separators has k+1 entries (an empty string counts as one).
	entries := strings.Count(uuid, ";") + 1
	return entries > setting.MaxDatasetNum
}

Loading…
Cancel
Save