Browse Source

Merge branch 'liuzx_trainjob' of https://git.openi.org.cn/OpenI/aiforge into liuzx_trainjob

pull/883/head
zhoupzh 3 years ago
parent
commit
bdb77758ac
7 changed files with 173 additions and 53 deletions
  1. +26
    -0
      models/cloudbrain.go
  2. +1
    -1
      modules/modelarts/modelarts.go
  3. +41
    -0
      modules/modelarts/resty.go
  4. +2
    -0
      routers/api/v1/api.go
  5. +102
    -0
      routers/api/v1/repo/modelarts.go
  6. +0
    -48
      routers/repo/modelarts.go
  7. +1
    -4
      routers/routes/routes.go

+ 26
- 0
models/cloudbrain.go View File

@@ -1149,6 +1149,32 @@ func deleteJob(e Engine, job *Cloudbrain) error {
return err
}

// DeleteJobVersion deletes the given Cloudbrain record by delegating to
// deleteJobVersion with the package-scope engine value x. The error reported
// by the delete, if any, is returned unchanged.
func DeleteJobVersion(job *Cloudbrain) error {
	err := deleteJobVersion(x, job)
	return err
}

// deleteJobVersion issues a delete for job through engine e, selecting the
// record by job.ID. The first value returned by Delete is discarded; only the
// error is propagated to the caller.
func deleteJobVersion(e Engine, job *Cloudbrain) error {
	if _, err := e.ID(job.ID).Delete(job); err != nil {
		return err
	}
	return nil
}

// func DeleteJobVersion(job *Cloudbrain, jobID string, versionName string) error {
// return deleteJobVersion(x, job, jobID, versionName)
// }

// func deleteJobVersion(e Engine, job *Cloudbrain, jobID string, versionName string) error {
// var sess *xorm.Session
// sess = e.Where("job_id = ? AND version_name !=?", jobID, versionName)
// _, err := sess.Delete(job)
// return err
// }

// func deleteJobVersion(e Engine, jobID string, versionName string) error {
// deleteCloudbrainSql := "delete from cloudbrain where job_id=" + jobID + "and version_name=" + versionName
// _, err := e.Exec(deleteCloudbrainSql)
// return err
// }

func GetCloudbrainByName(jobName string) (*Cloudbrain, error) {
cb := &Cloudbrain{JobName: jobName}
return getRepoCloudBrain(cb)


+ 1
- 1
modules/modelarts/modelarts.go View File

@@ -41,7 +41,7 @@ const (
JobPath = "/job/"
OrderDesc = "desc" //向下查询
OrderAsc = "asc" //向上查询
Lines = 20
Lines = 500
TrainUrl = "train_url"
DataUrl = "data_url"
PerPage = 10


+ 41
- 0
modules/modelarts/resty.go View File

@@ -814,3 +814,44 @@ sendjob:

return &result, nil
}

// DelTrainJobVersion sends a DELETE request for a single version of a
// training job and decodes the reply into a models.TrainJobResult.
//
// jobID and versionID are placed in the request path as
// ".../<jobID>/versions/<versionID>". When the first attempt is answered with
// 401 Unauthorized, getToken is called (presumably refreshing TOKEN - confirm)
// and the request is sent exactly one more time.
//
// An error is returned when the request itself fails, when the final status is
// not 200 OK, or when the decoded result has IsSuccess == false. The pointer to
// the result is returned in every case, including alongside an error.
func DelTrainJobVersion(jobID string, versionID string) (*models.TrainJobResult, error) {
	checkSetting()
	client := getRestyClient()
	var result models.TrainJobResult

	retry := 0

sendjob:
	res, err := client.R().
		SetAuthToken(TOKEN).
		SetResult(&result).
		Delete(HOST + "/v1/" + setting.ProjectID + urlTrainJob + "/" + jobID + "/versions/" + versionID)

	if err != nil {
		return &result, fmt.Errorf("resty DelTrainJobVersion: %v", err)
	}

	if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
		retry++
		// The result of getToken is deliberately discarded: if it did not
		// help, the retried request is answered with 401 again, retry is
		// already 1, and the non-200 branch below reports the failure.
		_ = getToken()
		goto sendjob
	}

	if res.StatusCode() != http.StatusOK {
		var temp models.ErrorResult
		if err = json.Unmarshal([]byte(res.String()), &temp); err != nil {
			log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error())
			return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error())
		}
		// Name this operation (not DelTrainJob) so version deletions can be
		// told apart from whole-job deletions in the log.
		log.Error("DelTrainJobVersion(%s, %s) failed(%d):%s(%s)", jobID, versionID, res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
		return &result, fmt.Errorf("删除训练作业版本失败(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg)
	}

	if !result.IsSuccess {
		log.Error("DelTrainJobVersion(%s, %s) failed", jobID, versionID)
		return &result, fmt.Errorf("删除训练作业版本失败:%s", result.ErrorMsg)
	}

	return &result, nil
}

+ 2
- 0
routers/api/v1/api.go View File

@@ -878,6 +878,8 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("", repo.GetModelArtsTrainJobVersion)
// m.Get("/log", repo.TrainJobGetLog)
m.Get("/log", repo.TrainJobGetLog)
m.Post("/del_version", repo.DelTrainJobVersion)
m.Post("/stop_version", repo.StopTrainJobVersion)
// m.Group("/:version-name", func() {
// m.Get("", repo.GetModelArtsTrainJobVersion)
// })


+ 102
- 0
routers/api/v1/repo/modelarts.go View File

@@ -13,6 +13,7 @@ import (
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/modelarts"
"code.gitea.io/gitea/modules/setting"
)

func GetModelArtsNotebook(ctx *context.APIContext) {
@@ -214,3 +215,104 @@ func trainJobGetLogContent(jobID string, versionName string, baseLine string, or

return resultLogFile, result, err
}

// DelTrainJobVersion is the API handler that deletes one version of a train
// job. The job is identified by the ":jobid" route parameter and the version
// by the "version_name" query value.
//
// Steps, in order:
//  1. look up the stored task for (jobID, versionName);
//  2. delete the version remotely via modelarts.DelTrainJobVersion;
//  3. delete the stored record via models.DeleteJobVersion;
//  4. re-count the remaining versions of the job and write the new count and
//     latest-version marker back.
//
// On success a JSON object with "JobID", "VersionName" and "StatusOK" is
// written with HTTP 200. Every failure path writes exactly one error response.
func DelTrainJobVersion(ctx *context.APIContext) {
	jobID := ctx.Params(":jobid")
	versionName := ctx.Query("version_name")

	task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
	if err != nil {
		// Do not read task here: the lookup failed, so it cannot be relied on
		// and dereferencing it may panic. Log the request identifiers instead.
		log.Error("GetCloudbrainByJobIDAndVersionName(%s, %s) failed:%v", jobID, versionName, err.Error())
		ctx.NotFound(err)
		return
	}

	_, err = modelarts.DelTrainJobVersion(jobID, strconv.FormatInt(task.VersionID, 10))
	if err != nil {
		log.Error("DelTrainJobVersion(%s) failed:%v", task.JobName, err.Error())
		ctx.NotFound(err)
		return
	}

	err = models.DeleteJobVersion(task)
	if err != nil {
		// Only one response may be written per request; the previous code
		// followed ServerError with a second NotFound call.
		ctx.ServerError("DeleteJobVersion failed", err)
		return
	}

	// Fetch the number of versions that remain after the deletion.
	repo := ctx.Repo.Repository
	page := ctx.QueryInt("page")
	if page <= 0 {
		page = 1
	}
	_, versionCount, err := models.CloudbrainsVersionList(&models.CloudbrainsOptions{
		ListOptions: models.ListOptions{
			Page:     page,
			PageSize: setting.UI.IssuePagingNum,
		},
		RepoID:  repo.ID,
		Type:    models.TypeCloudBrainTwo,
		JobType: string(models.JobTypeTrain),
		JobID:   jobID,
	})
	if err != nil {
		ctx.ServerError("get VersionListCount faild", err)
		return
	}

	// If the deleted task was the latest version, promote the initial version
	// (modelarts.InitFatherVersionName, "V0001" per the original note) to
	// latest. Otherwise the latest version stays as it is and only its stored
	// version count is refreshed.
	if task.IsLatestVersion == modelarts.IsLatestVersion {
		err = models.SetVersionCountAndLatestVersionByJobIDAndVersionName(jobID, modelarts.InitFatherVersionName, versionCount, modelarts.IsLatestVersion)
		if err != nil {
			ctx.ServerError("UpdateJobVersionCount failed", err)
			return
		}
	} else {
		latestTask, err := models.GetCloudbrainByJobIDAndIsLatestVersion(jobID, modelarts.IsLatestVersion)
		if err != nil {
			ctx.ServerError("GetCloudbrainByJobIDAndIsLatestVersion faild:", err)
			return
		}
		err = models.SetVersionCountAndLatestVersionByJobIDAndVersionName(jobID, latestTask.VersionName, versionCount, modelarts.IsLatestVersion)
		if err != nil {
			ctx.ServerError("UpdateJobVersionCount failed", err)
			return
		}
	}

	ctx.JSON(http.StatusOK, map[string]interface{}{
		"JobID":       jobID,
		"VersionName": versionName,
		"StatusOK":    0,
	})
}

// StopTrainJobVersion is the API handler that stops one version of a train
// job. The job is identified by the ":jobid" route parameter and the version
// by the "version_name" query value.
//
// The stored task is looked up first so that its VersionID can be passed to
// modelarts.StopTrainJob. On success a JSON object with "JobID",
// "VersionName" and "StatusOK" is written with HTTP 200. On failure an error
// response is written, so the caller is never left with an empty reply.
func StopTrainJobVersion(ctx *context.APIContext) {
	jobID := ctx.Params(":jobid")
	versionName := ctx.Query("version_name")

	task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
	if err != nil {
		// Do not read task here: the lookup failed, so it cannot be relied on
		// and dereferencing it may panic. Log the request identifiers instead.
		log.Error("GetCloudbrainByJobIDAndVersionName(%s, %s) failed:%v", jobID, versionName, err.Error())
		ctx.NotFound(err)
		return
	}

	_, err = modelarts.StopTrainJob(jobID, strconv.FormatInt(task.VersionID, 10))
	if err != nil {
		log.Error("StopTrainJob(%s) failed:%v", task.JobName, err.Error())
		ctx.ServerError("StopTrainJob failed", err)
		return
	}

	ctx.JSON(http.StatusOK, map[string]interface{}{
		"JobID":       jobID,
		"VersionName": versionName,
		"StatusOK":    0,
	})
}

+ 0
- 48
routers/repo/modelarts.go View File

@@ -1145,54 +1145,6 @@ func TrainJobStop(ctx *context.Context) {
ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelarts/train-job")
}

// TrainJobVersionDel is the web handler for deleting a train job version.
// It looks up the stored task for the ":jobid" route parameter and the
// requested version name, calls modelarts.DelTrainJob and models.DeleteJob,
// and then renders tplModelArtsTrainJobShow. Lookup and remote-delete
// failures re-render that template with the error message; a failure of
// models.DeleteJob is reported through ctx.ServerError.
//
// NOTE(review): the query key ":versionName" carries a leading colon, which
// looks like a route-parameter name rather than a query name - confirm what
// the client actually sends.
// NOTE(review): modelarts.DelTrainJob receives only jobID and models.DeleteJob
// receives the task; confirm that this removes just the requested version and
// not the whole job.
func TrainJobVersionDel(ctx *context.Context) {
	var jobID = ctx.Params(":jobid")
	var versionName = ctx.Query(":versionName")
	task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
	if err != nil {
		// Do not read task here: the lookup failed, so it cannot be relied on
		// and dereferencing it may panic. Log the request identifiers instead.
		log.Error("GetCloudbrainByJobIDAndVersionName(%s, %s) failed:%v", jobID, versionName, err.Error())
		ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil)
		return
	}

	_, err = modelarts.DelTrainJob(jobID)
	if err != nil {
		log.Error("DelTrainJob(%s) failed:%v", task.JobName, err.Error())
		ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil)
		return
	}

	err = models.DeleteJob(task)
	if err != nil {
		ctx.ServerError("DeleteJob failed", err)
		return
	}

	// ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelarts/train-job")
	ctx.HTML(http.StatusOK, tplModelArtsTrainJobShow)
}

// TrainJobVersionStop is the web handler for stopping a train job version.
// It looks up the stored task for the ":jobid" route parameter and the
// requested version name, passes the task's VersionID to
// modelarts.StopTrainJob, and renders tplModelArtsTrainJobShow on success.
// Either failure re-renders tplModelArtsTrainJobIndex with the error message.
//
// NOTE(review): the query key ":versionName" carries a leading colon, which
// looks like a route-parameter name rather than a query name - confirm what
// the client actually sends.
func TrainJobVersionStop(ctx *context.Context) {
	var jobID = ctx.Params(":jobid")
	var versionName = ctx.Query(":versionName")
	task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
	if err != nil {
		// Do not read task here: the lookup failed, so it cannot be relied on
		// and dereferencing it may panic. Log the request identifiers instead.
		log.Error("GetCloudbrainByJobIDAndVersionName(%s, %s) failed:%v", jobID, versionName, err.Error())
		ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobIndex, nil)
		return
	}

	_, err = modelarts.StopTrainJob(jobID, strconv.FormatInt(task.VersionID, 10))
	if err != nil {
		log.Error("StopTrainJob(%s) failed:%v", task.JobName, err.Error())
		ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobIndex, nil)
		return
	}

	// ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelarts/train-job")
	ctx.HTML(http.StatusOK, tplModelArtsTrainJobShow)
}

func canUserCreateTrainJob(uid int64) (bool, error) {
org, err := models.GetOrgByName(setting.AllowedOrg)
if err != nil {


+ 1
- 4
routers/routes/routes.go View File

@@ -989,16 +989,13 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("", reqRepoCloudBrainReader, repo.TrainJobShow)
m.Post("/stop", reqRepoCloudBrainWriter, repo.TrainJobStop)
m.Post("/del", reqRepoCloudBrainWriter, repo.TrainJobDel)
m.Get("/log", reqRepoCloudBrainReader, repo.TrainJobGetLog)
// m.Get("/log", reqRepoCloudBrainReader, repo.TrainJobGetLog)
m.Get("/models", reqRepoCloudBrainReader, repo.TrainJobShowModels)
m.Get("/download_model", reqRepoCloudBrainReader, repo.TrainJobDownloadModel)
m.Get("/version_models", reqRepoCloudBrainReader, repo.TrainJobVersionShowModels)
// m.Group("/:version-name", func() {
m.Get("/create_version", reqRepoCloudBrainReader, repo.TrainJobNewVersion)
m.Post("/create_version", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreateVersion)
// })
m.Post("/stop_version", reqRepoCloudBrainWriter, repo.TrainJobVersionStop)
m.Post("/del_version", reqRepoCloudBrainWriter, repo.TrainJobVersionDel)
})
m.Get("/create", reqRepoCloudBrainReader, repo.TrainJobNew)
m.Post("/create", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreate)


Loading…
Cancel
Save