Browse Source

Merge branch 'train-job' of https://git.openi.org.cn/OpenI/aiforge into train-job

pull/625/head
Gitea 4 years ago
parent
commit
8a9ccc3a7b
3 changed files with 132 additions and 13 deletions
  1. +38
    -2
      models/cloudbrain.go
  2. +61
    -7
      modules/modelarts/resty.go
  3. +33
    -4
      routers/repo/modelarts.go

+ 38
- 2
models/cloudbrain.go View File

@@ -499,8 +499,32 @@ type Config struct {
LogUrl string `json:"log_url"`
//UserImageUrl string `json:"user_image_url"`
//UserCommand string `json:"user_command"`
CreateVersion bool `json:"create_version"`
Volumes []Volumes `json:"volumes"`
//CreateVersion bool `json:"create_version"`
//Volumes []Volumes `json:"volumes"`
Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"`
}

// CreateConfigParams is the JSON request body that CreateTrainJobConfig posts
// to the training-job-configs endpoint in order to save a set of training-job
// parameters. The commented-out fields are currently not sent.
type CreateConfigParams struct {
ConfigName string `json:"config_name"`
Description string `json:"config_desc"`
WorkServerNum int `json:"worker_server_num"`
AppUrl string `json:"app_url"` // code directory of the training job
BootFileUrl string `json:"boot_file_url"` // boot (entry) file of the training job; must be located under the code directory
Parameter []Parameter `json:"parameter"`
DataUrl string `json:"data_url"` // OBS path URL of the dataset required by the training job
//DatasetID string `json:"dataset_id"`
//DataVersionID string `json:"dataset_version_id"`
//DataSource []DataSource `json:"data_source"`
//SpecID int64 `json:"spec_id"`
EngineID int64 `json:"engine_id"`
//ModelID int64 `json:"model_id"`
TrainUrl string `json:"train_url"` // OBS path URL for the output files of the training job
LogUrl string `json:"log_url"`
//UserImageUrl string `json:"user_image_url"`
//UserCommand string `json:"user_command"`
//CreateVersion bool `json:"create_version"`
//Volumes []Volumes `json:"volumes"`
Flavor Flavor `json:"flavor"`
PoolID string `json:"pool_id"`
}
@@ -552,6 +576,12 @@ type CreateTrainJobResult struct {
VersionName string `json:"version_name"`
}

// CreateTrainJobConfigResult is the result object that CreateTrainJobConfig
// registers with the HTTP request and returns to its caller.
// CreateTrainJobConfig treats IsSuccess == false as a failure and reports
// ErrorCode and ErrorMsg in the resulting error.
type CreateTrainJobConfigResult struct {
ErrorCode string `json:"error_code"`
ErrorMsg string `json:"error_msg"`
IsSuccess bool `json:"is_success"`
}

type GetResourceSpecsResult struct {
ErrorCode string `json:"error_code"`
ErrorMsg string `json:"error_msg"`
@@ -574,6 +604,12 @@ type Specs struct {
InterfaceType int `json:"interface_type"`
}

// ErrorResult is the payload the modelarts client unmarshals from the response
// body when a request comes back with a non-200 status, so that the error code
// and message can be included in the returned error.
//
// NOTE(review): the message is read from the "error_message" key here, whereas
// the *Result types in this file use "error_msg" — presumably the service uses
// a different key for HTTP-level errors; confirm against the API documentation.
type ErrorResult struct {
ErrorCode string `json:"error_code"`
ErrorMsg string `json:"error_message"`
IsSuccess bool `json:"is_success"`
}

func Cloudbrains(opts *CloudbrainsOptions) ([]*Cloudbrain, int64, error) {
sess := x.NewSession()
defer sess.Close()


+ 61
- 7
modules/modelarts/resty.go View File

@@ -25,6 +25,7 @@ const (
urlNotebook = "/demanager/instances"
urlTrainJob = "/training-jobs"
urlResourceSpecs = "/job/resource-specs"
urlTrainJobConfig = "/training-job-configs"

errorCodeExceedLimit = "ModelArts.0118"
)
@@ -86,7 +87,6 @@ func getToken() error {
}

TOKEN = res.Header().Get("X-Subject-Token")
log.Info(TOKEN)

return nil
}
@@ -296,8 +296,6 @@ func createTrainJob(createJobParams models.CreateTrainJobParams) (*models.Create
client := getRestyClient()
var result models.CreateTrainJobResult

log.Info("%+v",createJobParams)

retry := 0

sendjob:
@@ -323,8 +321,13 @@ sendjob:
}

if res.StatusCode() != http.StatusOK {
log.Error("createTrainJob failed", res.StatusCode(), res.RawResponse.Body, result.ErrorCode, result.ErrorMsg)
return &result, fmt.Errorf("createTrainJob failed(%d)", res.StatusCode())
var temp models.ErrorResult
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
}
log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
}

if !result.IsSuccess {
@@ -360,8 +363,13 @@ sendjob:
}

if res.StatusCode() != http.StatusOK {
log.Error("GetResourceSpecs failed(%d)", res.StatusCode())
return &result, fmt.Errorf("GetResourceSpecs failed(%d)", res.StatusCode())
var temp models.ErrorResult
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
}
log.Error("GetResourceSpecs failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
return &result, fmt.Errorf("GetResourceSpecs failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
}

if !result.IsSuccess {
@@ -371,3 +379,49 @@ sendjob:

return &result, nil
}

// CreateTrainJobConfig saves a set of training-job parameters by POSTing req
// as JSON to the training-job-configs endpoint of the configured project.
//
// If the service answers 401 Unauthorized, the token is refreshed once via
// getToken and the request is sent again. A failed token refresh is reported
// to the caller instead of being silently ignored.
//
// Returns:
//   - (nil, err) when the HTTP request itself could not be performed;
//   - (&result, err) when the token refresh fails, when the service answers
//     with a non-200 status, or when the decoded result has IsSuccess == false;
//   - (&result, nil) on success.
func CreateTrainJobConfig(req models.CreateConfigParams) (*models.CreateTrainJobConfigResult, error) {
	checkSetting()
	client := getRestyClient()
	var result models.CreateTrainJobConfigResult

	retry := 0

sendjob:
	res, err := client.R().
		SetHeader("Content-Type", "application/json").
		SetAuthToken(TOKEN).
		SetBody(req).
		SetResult(&result).
		Post(HOST + "/v1/" + setting.ProjectID + urlTrainJobConfig)

	if err != nil {
		return nil, fmt.Errorf("resty CreateTrainJobConfig: %s", err)
	}

	if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
		retry++
		// Retrying with the stale token cannot succeed, so stop here if the
		// refresh itself fails rather than masking it behind a second 401.
		if err = getToken(); err != nil {
			log.Error("CreateTrainJobConfig: getToken failed: %v", err)
			return &result, fmt.Errorf("CreateTrainJobConfig: getToken failed: %v", err)
		}
		goto sendjob
	}

	// Log the request payload for diagnostics; a marshalling problem is only
	// logged because the request has already been sent at this point.
	if payload, marshalErr := json.Marshal(req); marshalErr != nil {
		log.Error("json.Marshal failed: %v", marshalErr)
	} else {
		log.Info("%s", payload)
	}

	if res.StatusCode() != http.StatusOK {
		// Non-200 responses carry their details in an ErrorResult body.
		var errResult models.ErrorResult
		if err = json.Unmarshal([]byte(res.String()), &errResult); err != nil {
			log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
			return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
		}
		log.Error("CreateTrainJobConfig failed(%d):%s(%s)", res.StatusCode(), errResult.ErrorCode, errResult.ErrorMsg)
		return &result, fmt.Errorf("CreateTrainJobConfig failed(%d):%s(%s)", res.StatusCode(), errResult.ErrorCode, errResult.ErrorMsg)
	}

	if !result.IsSuccess {
		log.Error("CreateTrainJobConfig failed(%s): %s", result.ErrorCode, result.ErrorMsg)
		return &result, fmt.Errorf("CreateTrainJobConfig failed(%s): %s", result.ErrorCode, result.ErrorMsg)
	}

	return &result, nil
}

+ 33
- 4
routers/repo/modelarts.go View File

@@ -354,10 +354,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
return
}

if isSaveParam == "on" {
//todo: save param
}

if err := git.Clone(repo.RepoPath(), codeLocalPath, git.CloneRepoOptions{}); err != nil {
log.Error("Failed to clone repository: %s (%v)", repo.FullName(), err)
ctx.RenderWithErr("Failed to clone repository", tplModelArtsTrainJobNew, &form)
@@ -383,6 +379,39 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
return
}

if isSaveParam == "on" {
if form.ParameterTemplateName == "" {
log.Error("ParameterTemplateName is empty")
ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobNew, &form)
return
}

_, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{
ConfigName: form.ParameterTemplateName,
Description: form.PrameterDescription,
DataUrl: dataPath,
AppUrl: codeObsPath,
BootFileUrl: codeObsPath + bootFile,
TrainUrl: outputObsPath,
Flavor: models.Flavor{
Code: flavorCode,
},
WorkServerNum: workServerNumber,
EngineID: int64(engineID),
LogUrl: logObsPath,
PoolID: poolID,
Parameter: []models.Parameter{

},
})

if err != nil {
log.Error("Failed to CreateTrainJobConfig: %v", err)
ctx.RenderWithErr("保存作业参数失败:" + err.Error(), tplModelArtsTrainJobNew, &form)
return
}
}

req := &modelarts.GenerateTrainJobReq{
JobName: jobName,
DataUrl: dataPath,


Loading…
Cancel
Save