diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 78b40fd56..5988e62bb 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -1,13 +1,14 @@ package modelarts import ( - "code.gitea.io/gitea/modules/timeutil" "encoding/json" "errors" "fmt" "path" "strconv" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -96,6 +97,7 @@ type GenerateTrainJobReq struct { VersionCount int EngineName string TotalVersionCount int + DatasetName string } type GenerateInferenceJobReq struct { @@ -335,11 +337,6 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error return err } - attach, err := models.GetAttachmentByUUID(req.Uuid) - if err != nil { - log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error()) - return err - } jobId := strconv.FormatInt(jobResult.JobID, 10) err = models.CreateCloudbrain(&models.Cloudbrain{ Status: TransTrainJobStatus(jobResult.Status), @@ -353,7 +350,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error VersionID: jobResult.VersionID, VersionName: jobResult.VersionName, Uuid: req.Uuid, - DatasetName: attach.Name, + DatasetName: req.DatasetName, CommitID: req.CommitID, IsLatestVersion: req.IsLatestVersion, ComputeResource: models.NPUResource, @@ -408,12 +405,6 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job return err } - attach, err := models.GetAttachmentByUUID(req.Uuid) - if err != nil { - log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error()) - return err - } - var jobTypes []string jobTypes = append(jobTypes, string(models.JobTypeTrain)) repo := ctx.Repo.Repository @@ -441,7 +432,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job VersionID: jobResult.VersionID, VersionName: jobResult.VersionName, Uuid: req.Uuid, - DatasetName: attach.Name, + DatasetName: req.DatasetName, CommitID: req.CommitID, IsLatestVersion: req.IsLatestVersion, PreVersionName: req.PreVersionName, diff --git a/modules/setting/setting.go b/modules/setting/setting.go index 5c87b68c5..8c26c7b9e 100755 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -465,6 +465,7 @@ var ( MaxDuration int64 TrainGpuTypes string TrainResourceSpecs string + MaxDatasetNum int //benchmark config IsBenchmarkEnabled bool @@ -1294,6 +1295,7 @@ func NewContext() { MaxDuration = sec.Key("MAX_DURATION").MustInt64(14400) TrainGpuTypes = sec.Key("TRAIN_GPU_TYPES").MustString("") TrainResourceSpecs = sec.Key("TRAIN_RESOURCE_SPECS").MustString("") + MaxDatasetNum = sec.Key("MAX_DATASET_NUM").MustInt(5) sec = Cfg.Section("benchmark") IsBenchmarkEnabled = sec.Key("ENABLED").MustBool(false) diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 12a5a0623..0a4a9bbc1 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -979,12 +979,26 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/" logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/" - dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" + // dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" branch_name := form.BranchName isLatestVersion := modelarts.IsLatestVersion FlavorName := form.FlavorName VersionCount := modelarts.VersionCount EngineName := form.EngineName + if IsDatasetUseCountExceed(uuid) { + log.Error("DatasetUseCount is Exceed:%v") + trainJobErrorNewDataPrepare(ctx, form) + ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobNew, &form) + return + } + datasetName, err := GetDatasetNameByUUID(uuid) + if err != nil { + log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"]) + trainJobErrorNewDataPrepare(ctx, form) + ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobNew, &form) + return + } + dataPath := GetObsDataPathByUUID(uuid) count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) if err != nil { @@ -1161,6 +1175,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) EngineName: EngineName, VersionCount: VersionCount, TotalVersionCount: modelarts.TotalVersionCount, + DatasetName: datasetName, } //将params转换Parameters.Parameter,出错时返回给前端 @@ -1222,13 +1237,28 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/" outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/" logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/" - dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" + // dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" branch_name := form.BranchName PreVersionName := form.VersionName FlavorName := form.FlavorName EngineName := form.EngineName isLatestVersion := modelarts.IsLatestVersion + if IsDatasetUseCountExceed(uuid) { + log.Error("DatasetUseCount is Exceed:%v") + versionErrorDataPrepare(ctx, form) + ctx.RenderWithErr("DatasetUseCount is Exceed", tplModelArtsTrainJobVersionNew, &form) + return + } + datasetName, err := GetDatasetNameByUUID(uuid) + if err != nil { + log.Error("GetDatasetNameByUUID failed:%v", err, ctx.Data["MsgID"]) + versionErrorDataPrepare(ctx, form) + ctx.RenderWithErr("GetDatasetNameByUUID error", tplModelArtsTrainJobVersionNew, &form) + return + } + dataPath := GetObsDataPathByUUID(uuid) + canNewJob, _ := canUserCreateTrainJobVersion(ctx, latestTask.UserID) if !canNewJob { ctx.RenderWithErr("user cann't new trainjob", tplModelArtsTrainJobVersionNew, &form) @@ -1386,6 +1416,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ EngineName: EngineName, PreVersionName: PreVersionName, TotalVersionCount: latestTask.TotalVersionCount + 1, + DatasetName: datasetName, } err = modelarts.GenerateTrainJobVersion(ctx, req, jobID) @@ -2420,3 +2451,39 @@ func TrainJobDownloadLogFile(ctx *context.Context) { ctx.Resp.Header().Set("Cache-Control", "max-age=0") http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) } +func GetObsDataPathByUUID(uuid string) string { + var obsDataPath string + uuidList := strings.Split(uuid, ";") + for k, _ := range uuidList { + if k <= 0 { + obsDataPath = "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" + } + if k > 0 { + obsDataPathNext := ";" + "s3://" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" + obsDataPath = obsDataPath + obsDataPathNext + } + + } + return obsDataPath +} +func GetDatasetNameByUUID(uuid string) (string, error) { + uuidList := strings.Split(uuid, ";") + var datasetName string + for _, uuidStr := range uuidList { + attach, err := models.GetAttachmentByUUID(uuidStr) + if err != nil { + log.Error("GetAttachmentByUUID failed:%v", err) + return "", err + } + datasetName = datasetName + attach.Name + ";" + } + return datasetName, nil +} +func IsDatasetUseCountExceed(uuid string) bool { + uuidList := strings.Split(uuid, ";") + if len(uuidList) > setting.MaxDatasetNum { + return true + } else { + return false + } +}