diff --git a/models/cloudbrain.go b/models/cloudbrain.go
index 1210bfb7e..d3bf77922 100755
--- a/models/cloudbrain.go
+++ b/models/cloudbrain.go
@@ -499,8 +499,33 @@ type Config struct {
 	LogUrl string `json:"log_url"`
 	//UserImageUrl string `json:"user_image_url"`
 	//UserCommand string `json:"user_command"`
-	CreateVersion bool `json:"create_version"`
-	Volumes []Volumes `json:"volumes"`
+	//CreateVersion bool `json:"create_version"`
+	//Volumes []Volumes `json:"volumes"`
+	Flavor Flavor `json:"flavor"`
+	PoolID string `json:"pool_id"`
+}
+
+// CreateConfigParams is the request body sent by modelarts.CreateTrainJobConfig.
+type CreateConfigParams struct {
+	ConfigName    string      `json:"config_name"`
+	Description   string      `json:"config_desc"`
+	WorkServerNum int         `json:"worker_server_num"`
+	AppUrl        string      `json:"app_url"`       // code directory of the training job
+	BootFileUrl   string      `json:"boot_file_url"` // boot file of the training job; must be under the code directory
+	Parameter     []Parameter `json:"parameter"`
+	DataUrl       string      `json:"data_url"` // OBS path URL of the dataset required by the training job
+	//DatasetID string `json:"dataset_id"`
+	//DataVersionID string `json:"dataset_version_id"`
+	//DataSource []DataSource `json:"data_source"`
+	//SpecID int64 `json:"spec_id"`
+	EngineID int64 `json:"engine_id"`
+	//ModelID int64 `json:"model_id"`
+	TrainUrl string `json:"train_url"` // OBS path URL for the training job's output files
+	LogUrl   string `json:"log_url"`
+	//UserImageUrl string `json:"user_image_url"`
+	//UserCommand string `json:"user_command"`
+	//CreateVersion bool `json:"create_version"`
+	//Volumes []Volumes `json:"volumes"`
 	Flavor Flavor `json:"flavor"`
 	PoolID string `json:"pool_id"`
 }
@@ -552,6 +577,13 @@ type CreateTrainJobResult struct {
 	VersionName string `json:"version_name"`
 }
 
+// CreateTrainJobConfigResult is the response body decoded by modelarts.CreateTrainJobConfig.
+type CreateTrainJobConfigResult struct {
+	ErrorCode string `json:"error_code"`
+	ErrorMsg  string `json:"error_msg"`
+	IsSuccess bool   `json:"is_success"`
+}
+
 type GetResourceSpecsResult struct {
 	ErrorCode string `json:"error_code"`
 	ErrorMsg string `json:"error_msg"`
@@ -574,6 +606,15 @@ type Specs struct {
 	InterfaceType int `json:"interface_type"`
 }
 
+// ErrorResult is the body decoded from non-200 ModelArts responses.
+// NOTE(review): the message key is "error_message" here but "error_msg" in the
+// *Result types; confirm both against the ModelArts API.
+type ErrorResult struct {
+	ErrorCode string `json:"error_code"`
+	ErrorMsg  string `json:"error_message"`
+	IsSuccess bool   `json:"is_success"`
+}
+
 func Cloudbrains(opts *CloudbrainsOptions) ([]*Cloudbrain, int64, error) {
 	sess := x.NewSession()
 	defer sess.Close()
diff --git a/modules/modelarts/resty.go b/modules/modelarts/resty.go
index 2587eaa5d..a9c1ffa76 100755
--- a/modules/modelarts/resty.go
+++ b/modules/modelarts/resty.go
@@ -25,6 +25,7 @@ const (
 	urlNotebook = "/demanager/instances"
 	urlTrainJob = "/training-jobs"
 	urlResourceSpecs = "/job/resource-specs"
+	urlTrainJobConfig = "/training-job-configs"
 
 	errorCodeExceedLimit = "ModelArts.0118"
 )
@@ -86,7 +87,6 @@ func getToken() error {
 	}
 
 	TOKEN = res.Header().Get("X-Subject-Token")
-	log.Info(TOKEN)
 
 	return nil
 }
@@ -296,8 +296,6 @@ func createTrainJob(createJobParams models.CreateTrainJobParams) (*models.Create
 	client := getRestyClient()
 	var result models.CreateTrainJobResult
 
-	log.Info("%+v",createJobParams)
-
 	retry := 0
 
 sendjob:
@@ -323,8 +321,13 @@ sendjob:
 	}
 
 	if res.StatusCode() != http.StatusOK {
-		log.Error("createTrainJob failed", res.StatusCode(), res.RawResponse.Body, result.ErrorCode, result.ErrorMsg)
-		return &result, fmt.Errorf("createTrainJob failed(%d)", res.StatusCode())
+		var temp models.ErrorResult
+		if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
+			log.Error("createTrainJob failed(%d): json.Unmarshal failed(%s): %v", res.StatusCode(), res.String(), err)
+			return &result, fmt.Errorf("createTrainJob failed(%d): json.Unmarshal failed(%s): %v", res.StatusCode(), res.String(), err)
+		}
+		log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
+		return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
 	}
 
 	if !result.IsSuccess {
@@ -360,8 +363,13 @@ sendjob:
 	}
 
 	if res.StatusCode() != http.StatusOK {
-		log.Error("GetResourceSpecs failed(%d)", res.StatusCode())
-		return &result, fmt.Errorf("GetResourceSpecs failed(%d)", res.StatusCode())
+		var temp models.ErrorResult
+		if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
+			log.Error("GetResourceSpecs failed(%d): json.Unmarshal failed(%s): %v", res.StatusCode(), res.String(), err)
+			return &result, fmt.Errorf("GetResourceSpecs failed(%d): json.Unmarshal failed(%s): %v", res.StatusCode(), res.String(), err)
+		}
+		log.Error("GetResourceSpecs failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
+		return &result, fmt.Errorf("GetResourceSpecs failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
 	}
 
 	if !result.IsSuccess {
@@ -371,3 +379,54 @@ sendjob:
 
 	return &result, nil
 }
+
+// CreateTrainJobConfig POSTs req as JSON to the ModelArts training-job-configs
+// endpoint (HOST + "/v1/" + project ID + urlTrainJobConfig).
+//
+// A 401 response triggers one token refresh and a single retry. A non-200 status, or
+// a decoded result whose IsSuccess is false, is logged and returned as an error.
+func CreateTrainJobConfig(req models.CreateConfigParams) (*models.CreateTrainJobConfigResult, error) {
+	checkSetting()
+	client := getRestyClient()
+	var result models.CreateTrainJobConfigResult
+
+	retry := 0
+
+sendjob:
+	res, err := client.R().
+		SetHeader("Content-Type", "application/json").
+		SetAuthToken(TOKEN).
+		SetBody(req).
+		SetResult(&result).
+		Post(HOST + "/v1/" + setting.ProjectID + urlTrainJobConfig)
+
+	if err != nil {
+		return nil, fmt.Errorf("resty CreateTrainJobConfig: %s", err)
+	}
+
+	if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
+		retry++
+		// Retrying with the stale token can only fail again, so surface the refresh error itself.
+		if err = getToken(); err != nil {
+			return nil, fmt.Errorf("CreateTrainJobConfig: refresh token: %v", err)
+		}
+		goto sendjob
+	}
+
+	if res.StatusCode() != http.StatusOK {
+		var errResult models.ErrorResult
+		if err = json.Unmarshal([]byte(res.String()), &errResult); err != nil {
+			log.Error("CreateTrainJobConfig failed(%d): json.Unmarshal failed(%s): %v", res.StatusCode(), res.String(), err)
+			return &result, fmt.Errorf("CreateTrainJobConfig failed(%d): json.Unmarshal failed(%s): %v", res.StatusCode(), res.String(), err)
+		}
+		log.Error("CreateTrainJobConfig failed(%d):%s(%s)", res.StatusCode(), errResult.ErrorCode, errResult.ErrorMsg)
+		return &result, fmt.Errorf("CreateTrainJobConfig failed(%d):%s(%s)", res.StatusCode(), errResult.ErrorCode, errResult.ErrorMsg)
+	}
+
+	if !result.IsSuccess {
+		log.Error("CreateTrainJobConfig failed(%s): %s", result.ErrorCode, result.ErrorMsg)
+		return &result, fmt.Errorf("CreateTrainJobConfig failed(%s): %s", result.ErrorCode, result.ErrorMsg)
+	}
+
+	return &result, nil
+}
diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go
index 0c442af8b..7edb62c94 100755
--- a/routers/repo/modelarts.go
+++ b/routers/repo/modelarts.go
@@ -354,10 +354,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
 		return
 	}
 
-	if isSaveParam == "on" {
-		//todo: save param
-	}
-
 	if err := git.Clone(repo.RepoPath(), codeLocalPath, git.CloneRepoOptions{}); err != nil {
 		log.Error("Failed to clone repository: %s (%v)", repo.FullName(), err)
 		ctx.RenderWithErr("Failed to clone repository", tplModelArtsTrainJobNew, &form)
@@ -383,6 +379,39 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
 		return
 	}
 
+	if isSaveParam == "on" { // save the submitted settings as a ModelArts config before the train-job request is built
+		if form.ParameterTemplateName == "" {
+			log.Error("ParameterTemplateName is empty")
+			ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobNew, &form)
+			return
+		}
+
+		_, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{
+			ConfigName:  form.ParameterTemplateName,
+			Description: form.PrameterDescription,
+			DataUrl:     dataPath,
+			AppUrl:      codeObsPath,
+			BootFileUrl: codeObsPath + bootFile,
+			TrainUrl:    outputObsPath,
+			Flavor: models.Flavor{
+				Code: flavorCode,
+			},
+			WorkServerNum: workServerNumber,
+			EngineID:      int64(engineID),
+			LogUrl:        logObsPath,
+			PoolID:        poolID,
+			Parameter: []models.Parameter{
+				// NOTE(review): no run parameters are stored with the template yet; confirm this is intended.
+			},
+		})
+
+		if err != nil {
+			log.Error("Failed to CreateTrainJobConfig: %v", err)
+			ctx.RenderWithErr("保存作业参数失败:"+err.Error(), tplModelArtsTrainJobNew, &form)
+			return
+		}
+	}
+
 	req := &modelarts.GenerateTrainJobReq{
 		JobName: jobName,
 		DataUrl: dataPath,