diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 1eaeffcd3..d3e93a437 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -1070,6 +1070,12 @@ type CreateInferenceJobParams struct { InfConfig InfConfig `json:"config"` WorkspaceID string `json:"workspace_id"` } +type CreateInfUserImageParams struct { + JobName string `json:"job_name"` + Description string `json:"job_desc"` + Config InfUserImageConfig `json:"config"` + WorkspaceID string `json:"workspace_id"` +} type InfConfig struct { WorkServerNum int `json:"worker_server_num"` @@ -1084,6 +1090,21 @@ type InfConfig struct { PoolID string `json:"pool_id"` } +type InfUserImageConfig struct { + WorkServerNum int `json:"worker_server_num"` + AppUrl string `json:"app_url"` //训练作业的代码目录 + BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 + Parameter []Parameter `json:"parameter"` + DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL + EngineID int64 `json:"engine_id"` + LogUrl string `json:"log_url"` + CreateVersion bool `json:"create_version"` + Flavor Flavor `json:"flavor"` + PoolID string `json:"pool_id"` + UserImageUrl string `json:"user_image_url"` + UserCommand string `json:"user_command"` +} + type CreateTrainJobVersionParams struct { Description string `json:"job_desc"` Config TrainJobVersionConfig `json:"config"` @@ -2024,7 +2045,7 @@ func GetCloudbrainRunCountByRepoID(repoID int64) (int, error) { } func GetModelSafetyCountByUserID(userID int64) (int, error) { - count, err := x.In("status", JobWaiting, JobRunning,ModelArtsTrainJobInit,ModelArtsTrainJobImageCreating,ModelArtsTrainJobSubmitTrying,ModelArtsTrainJobScaling,ModelArtsTrainJobCheckInit,ModelArtsTrainJobCheckRunning,ModelArtsTrainJobCheckRunningCompleted).And("job_type = ? and user_id = ?", string(JobTypeModelSafety), userID).Count(new(Cloudbrain)) + count, err := x.In("status", JobWaiting, JobRunning, ModelArtsTrainJobInit, ModelArtsTrainJobImageCreating, ModelArtsTrainJobSubmitTrying, ModelArtsTrainJobScaling, ModelArtsTrainJobCheckInit, ModelArtsTrainJobCheckRunning, ModelArtsTrainJobCheckRunningCompleted).And("job_type = ? and user_id = ?", string(JobTypeModelSafety), userID).Count(new(Cloudbrain)) return int(count), err } diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 06521993e..567f6d620 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -143,6 +143,8 @@ type GenerateInferenceJobReq struct { Spec *models.Specification DatasetName string JobType string + UserImageUrl string + UserCommand string } type VersionInfo struct { @@ -682,26 +684,51 @@ func GetOutputPathByCount(TotalVersionCount int) (VersionOutputPath string) { func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) { createTime := timeutil.TimeStampNow() - jobResult, err := createInferenceJob(models.CreateInferenceJobParams{ - JobName: req.JobName, - Description: req.Description, - InfConfig: models.InfConfig{ - WorkServerNum: req.WorkServerNumber, - AppUrl: req.CodeObsPath, - BootFileUrl: req.BootFileUrl, - DataUrl: req.DataUrl, - EngineID: req.EngineID, - // TrainUrl: req.TrainUrl, - LogUrl: req.LogUrl, - PoolID: req.PoolID, - CreateVersion: true, - Flavor: models.Flavor{ - Code: req.Spec.SourceSpecId, + var jobResult *models.CreateTrainJobResult + var createErr error + if req.EngineID < 0 { + jobResult, createErr = createInferenceJobUserImage(models.CreateInfUserImageParams{ + JobName: req.JobName, + Description: req.Description, + Config: models.InfUserImageConfig{ + WorkServerNum: req.WorkServerNumber, + AppUrl: req.CodeObsPath, + BootFileUrl: req.BootFileUrl, + DataUrl: req.DataUrl, + // TrainUrl: req.TrainUrl, + LogUrl: req.LogUrl, + PoolID: req.PoolID, + CreateVersion: true, + Flavor: models.Flavor{ + Code: req.Spec.SourceSpecId, + }, + Parameter: req.Parameters, + UserImageUrl: req.UserImageUrl, + UserCommand: req.UserCommand, }, - Parameter: req.Parameters, - }, - }) - if err != nil { + }) + } else { + jobResult, createErr = createInferenceJob(models.CreateInferenceJobParams{ + JobName: req.JobName, + Description: req.Description, + InfConfig: models.InfConfig{ + WorkServerNum: req.WorkServerNumber, + AppUrl: req.CodeObsPath, + BootFileUrl: req.BootFileUrl, + DataUrl: req.DataUrl, + EngineID: req.EngineID, + // TrainUrl: req.TrainUrl, + LogUrl: req.LogUrl, + PoolID: req.PoolID, + CreateVersion: true, + Flavor: models.Flavor{ + Code: req.Spec.SourceSpecId, + }, + Parameter: req.Parameters, + }, + }) + } + if createErr != nil { log.Error("createInferenceJob failed: %v", err.Error()) if strings.HasPrefix(err.Error(), UnknownErrorPrefix) { log.Info("(%s)unknown error, set temp status", req.DisplayJobName) diff --git a/modules/modelarts/resty.go b/modules/modelarts/resty.go index fd1c467f3..c38300606 100755 --- a/modules/modelarts/resty.go +++ b/modules/modelarts/resty.go @@ -1197,6 +1197,66 @@ sendjob: return &result, nil } +func createInferenceJobUserImage(createJobParams models.CreateInfUserImageParams) (*models.CreateTrainJobResult, error) { + checkSetting() + client := getRestyClient() + var result models.CreateTrainJobResult + + retry := 0 + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetAuthToken(TOKEN). + SetBody(createJobParams). + SetResult(&result). + Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob) + + if err != nil { + return nil, fmt.Errorf("resty create train-job: %s", err) + } + + req, _ := json.Marshal(createJobParams) + log.Info("%s", req) + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + _ = getToken() + goto sendjob + } + + if res.StatusCode() != http.StatusOK { + var temp models.ErrorResult + if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { + log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + } + log.Error("createInferenceJobUserImage failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + bootFileErrorMsg := "Invalid OBS path '" + createJobParams.Config.BootFileUrl + "'." + dataSetErrorMsg := "Invalid OBS path '" + createJobParams.Config.DataUrl + "'." + if temp.ErrorMsg == bootFileErrorMsg { + log.Error("启动文件错误!createInferenceJobUserImage failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("启动文件错误!") + } + if temp.ErrorMsg == dataSetErrorMsg { + log.Error("数据集错误!createInferenceJobUserImage failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("数据集错误!") + } + if res.StatusCode() == http.StatusBadGateway { + return &result, fmt.Errorf(UnknownErrorPrefix+"createInferenceJobUserImage failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + } else { + return &result, fmt.Errorf("createInferenceJobUserImage failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + } + } + + if !result.IsSuccess { + log.Error("createInferenceJobUserImage failed(%s): %s", result.ErrorCode, result.ErrorMsg) + return &result, fmt.Errorf("createInferenceJobUserImage failed(%s): %s", result.ErrorCode, result.ErrorMsg) + } + + return &result, nil +} + func createNotebook2(createJobParams models.CreateNotebook2Params) (*models.CreateNotebookResult, error) { checkSetting() client := getRestyClient() diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 07c1fcd3e..be59b0f3f 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -1312,6 +1312,36 @@ func getUserCommand(engineId int, req *modelarts.GenerateTrainJobReq) (string, s return userCommand, userImageUrl } +func getInfJobUserCommand(engineId int, req *modelarts.GenerateInferenceJobReq) (string, string) { + userImageUrl := "" + userCommand := "" + if engineId < 0 { + tmpCodeObsPath := strings.Trim(req.CodeObsPath, "/") + tmpCodeObsPaths := strings.Split(tmpCodeObsPath, "/") + lastCodeDir := "code" + if len(tmpCodeObsPaths) > 0 { + lastCodeDir = tmpCodeObsPaths[len(tmpCodeObsPaths)-1] + } + userCommand = "/bin/bash /home/work/run_train.sh 's3://" + req.CodeObsPath + "' '" + lastCodeDir + "/" + req.BootFile + "' '/tmp/log/train.log' --'data_url'='s3://" + req.DataUrl + "' --'train_url'='s3://" + req.TrainUrl + "'" + var versionInfos modelarts.VersionInfo + if err := json.Unmarshal([]byte(setting.EngineVersions), &versionInfos); err != nil { + log.Info("json parse err." + err.Error()) + } else { + for _, engine := range versionInfos.Version { + if engine.ID == engineId { + userImageUrl = engine.Url + break + } + } + } + for _, param := range req.Parameters { + userCommand += " --'" + param.Label + "'='" + param.Value + "'" + } + return userCommand, userImageUrl + } + return userCommand, userImageUrl +} + func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) { ctx.Data["PageIsTrainJob"] = true var jobID = ctx.Params(":jobid") @@ -2171,6 +2201,10 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference JobType: string(models.JobTypeInference), } + userCommand, userImageUrl := getInfJobUserCommand(engineID, req) + req.UserCommand = userCommand + req.UserImageUrl = userImageUrl + err = modelarts.GenerateInferenceJob(ctx, req) if err != nil { log.Error("GenerateTrainJob failed:%v", err.Error())