@@ -499,8 +499,32 @@ type Config struct { | |||
LogUrl string `json:"log_url"` | |||
//UserImageUrl string `json:"user_image_url"` | |||
//UserCommand string `json:"user_command"` | |||
CreateVersion bool `json:"create_version"` | |||
Volumes []Volumes `json:"volumes"` | |||
//CreateVersion bool `json:"create_version"` | |||
//Volumes []Volumes `json:"volumes"` | |||
Flavor Flavor `json:"flavor"` | |||
PoolID string `json:"pool_id"` | |||
} | |||
type CreateConfigParams struct { | |||
ConfigName string `json:"config_name"` | |||
Description string `json:"config_desc"` | |||
WorkServerNum int `json:"worker_server_num"` | |||
AppUrl string `json:"app_url"` //训练作业的代码目录 | |||
BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 | |||
Parameter []Parameter `json:"parameter"` | |||
DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL | |||
//DatasetID string `json:"dataset_id"` | |||
//DataVersionID string `json:"dataset_version_id"` | |||
//DataSource []DataSource `json:"data_source"` | |||
//SpecID int64 `json:"spec_id"` | |||
EngineID int64 `json:"engine_id"` | |||
//ModelID int64 `json:"model_id"` | |||
TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL | |||
LogUrl string `json:"log_url"` | |||
//UserImageUrl string `json:"user_image_url"` | |||
//UserCommand string `json:"user_command"` | |||
//CreateVersion bool `json:"create_version"` | |||
//Volumes []Volumes `json:"volumes"` | |||
Flavor Flavor `json:"flavor"` | |||
PoolID string `json:"pool_id"` | |||
} | |||
@@ -552,6 +576,12 @@ type CreateTrainJobResult struct { | |||
VersionName string `json:"version_name"` | |||
} | |||
// CreateTrainJobConfigResult is the response body decoded after a request to
// save a training-job configuration.
type CreateTrainJobConfigResult struct {
	// ErrorCode and ErrorMsg carry the service's error details, decoded from
	// "error_code" and "error_msg".
	ErrorCode string `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
	// IsSuccess is decoded from "is_success"; callers treat false as failure.
	IsSuccess bool `json:"is_success"`
}
type GetResourceSpecsResult struct { | |||
ErrorCode string `json:"error_code"` | |||
ErrorMsg string `json:"error_msg"` | |||
@@ -574,6 +604,12 @@ type Specs struct { | |||
InterfaceType int `json:"interface_type"` | |||
} | |||
// ErrorResult is the shape into which the body of a non-200 response is
// decoded so that the error code and message can be reported.
type ErrorResult struct {
	ErrorCode string `json:"error_code"`
	// ErrorMsg is read from "error_message" (not "error_msg", which the
	// per-call result structs use).
	ErrorMsg  string `json:"error_message"`
	IsSuccess bool   `json:"is_success"`
}
func Cloudbrains(opts *CloudbrainsOptions) ([]*Cloudbrain, int64, error) { | |||
sess := x.NewSession() | |||
defer sess.Close() | |||
@@ -25,6 +25,7 @@ const ( | |||
urlNotebook = "/demanager/instances" | |||
urlTrainJob = "/training-jobs" | |||
urlResourceSpecs = "/job/resource-specs" | |||
urlTrainJobConfig = "/training-job-configs" | |||
errorCodeExceedLimit = "ModelArts.0118" | |||
) | |||
@@ -86,7 +87,6 @@ func getToken() error { | |||
} | |||
TOKEN = res.Header().Get("X-Subject-Token") | |||
log.Info(TOKEN) | |||
return nil | |||
} | |||
@@ -296,8 +296,6 @@ func createTrainJob(createJobParams models.CreateTrainJobParams) (*models.Create | |||
client := getRestyClient() | |||
var result models.CreateTrainJobResult | |||
log.Info("%+v",createJobParams) | |||
retry := 0 | |||
sendjob: | |||
@@ -323,8 +321,13 @@ sendjob: | |||
} | |||
if res.StatusCode() != http.StatusOK { | |||
log.Error("createTrainJob failed", res.StatusCode(), res.RawResponse.Body, result.ErrorCode, result.ErrorMsg) | |||
return &result, fmt.Errorf("createTrainJob failed(%d)", res.StatusCode()) | |||
var temp models.ErrorResult | |||
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { | |||
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
} | |||
log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
} | |||
if !result.IsSuccess { | |||
@@ -360,8 +363,13 @@ sendjob: | |||
} | |||
if res.StatusCode() != http.StatusOK { | |||
log.Error("GetResourceSpecs failed(%d)", res.StatusCode()) | |||
return &result, fmt.Errorf("GetResourceSpecs failed(%d)", res.StatusCode()) | |||
var temp models.ErrorResult | |||
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { | |||
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
} | |||
log.Error("GetResourceSpecs failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
return &result, fmt.Errorf("GetResourceSpecs failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
} | |||
if !result.IsSuccess { | |||
@@ -371,3 +379,49 @@ sendjob: | |||
return &result, nil | |||
} | |||
func CreateTrainJobConfig(req models.CreateConfigParams) (*models.CreateTrainJobConfigResult, error) { | |||
checkSetting() | |||
client := getRestyClient() | |||
var result models.CreateTrainJobConfigResult | |||
retry := 0 | |||
sendjob: | |||
res, err := client.R(). | |||
SetHeader("Content-Type", "application/json"). | |||
SetAuthToken(TOKEN). | |||
SetBody(req). | |||
SetResult(&result). | |||
Post(HOST + "/v1/" + setting.ProjectID + urlTrainJobConfig) | |||
if err != nil { | |||
return nil, fmt.Errorf("resty CreateTrainJobConfig: %s", err) | |||
} | |||
if res.StatusCode() == http.StatusUnauthorized && retry < 1 { | |||
retry++ | |||
_ = getToken() | |||
goto sendjob | |||
} | |||
temp, _ := json.Marshal(req) | |||
log.Info("%s", temp) | |||
if res.StatusCode() != http.StatusOK { | |||
var temp models.ErrorResult | |||
if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { | |||
log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||
} | |||
log.Error("CreateTrainJobConfig failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
return &result, fmt.Errorf("CreateTrainJobConfig failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||
} | |||
if !result.IsSuccess { | |||
log.Error("CreateTrainJobConfig failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||
return &result, fmt.Errorf("CreateTrainJobConfig failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||
} | |||
return &result, nil | |||
} |
@@ -354,10 +354,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) | |||
return | |||
} | |||
if isSaveParam == "on" { | |||
//todo: save param | |||
} | |||
if err := git.Clone(repo.RepoPath(), codeLocalPath, git.CloneRepoOptions{}); err != nil { | |||
log.Error("Failed to clone repository: %s (%v)", repo.FullName(), err) | |||
ctx.RenderWithErr("Failed to clone repository", tplModelArtsTrainJobNew, &form) | |||
@@ -383,6 +379,39 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) | |||
return | |||
} | |||
if isSaveParam == "on" { | |||
if form.ParameterTemplateName == "" { | |||
log.Error("ParameterTemplateName is empty") | |||
ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobNew, &form) | |||
return | |||
} | |||
_, err := modelarts.CreateTrainJobConfig(models.CreateConfigParams{ | |||
ConfigName: form.ParameterTemplateName, | |||
Description: form.PrameterDescription, | |||
DataUrl: dataPath, | |||
AppUrl: codeObsPath, | |||
BootFileUrl: codeObsPath + bootFile, | |||
TrainUrl: outputObsPath, | |||
Flavor: models.Flavor{ | |||
Code: flavorCode, | |||
}, | |||
WorkServerNum: workServerNumber, | |||
EngineID: int64(engineID), | |||
LogUrl: logObsPath, | |||
PoolID: poolID, | |||
Parameter: []models.Parameter{ | |||
}, | |||
}) | |||
if err != nil { | |||
log.Error("Failed to CreateTrainJobConfig: %v", err) | |||
ctx.RenderWithErr("保存作业参数失败:" + err.Error(), tplModelArtsTrainJobNew, &form) | |||
return | |||
} | |||
} | |||
req := &modelarts.GenerateTrainJobReq{ | |||
JobName: jobName, | |||
DataUrl: dataPath, | |||