Browse Source

Merge pull request 'trainjob fix bug' (#962) from liuzx_trainjob into V20211115

Reviewed-on: https://git.openi.org.cn/OpenI/aiforge/pulls/962
Reviewed-by: zhoupzh <zhoupzh@pcl.ac.cn>
pull/963/head
zhoupzh 3 years ago
parent
commit
b17282db8f
5 changed files with 28 additions and 100 deletions
  1. +0
    -17
      modules/storage/obs.go
  2. +0
    -1
      routers/api/v1/api.go
  3. +0
    -61
      routers/api/v1/repo/modelarts.go
  4. +20
    -13
      routers/repo/modelarts.go
  5. +8
    -8
      templates/repo/modelarts/trainjob/show.tmpl

+ 0
- 17
modules/storage/obs.go View File

@@ -337,20 +337,3 @@ func ObsCreateObject(path string) error {


return nil return nil
} }

func ObsDownloadAFile(bucket string, key string) (io.ReadCloser, error) {
input := &obs.GetObjectInput{}
input.Bucket = bucket
input.Key = key
output, err := ObsCli.GetObject(input)
if err == nil {
log.Info("StorageClass:%s, ETag:%s, ContentType:%s, ContentLength:%d, LastModified:%s\n",
output.StorageClass, output.ETag, output.ContentType, output.ContentLength, output.LastModified)
return output.Body, nil
} else if obsError, ok := err.(obs.ObsError); ok {
log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message)
return nil, obsError
} else {
return nil, err
}
}

+ 0
- 1
routers/api/v1/api.go View File

@@ -880,7 +880,6 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Post("/del_version", repo.DelTrainJobVersion) m.Post("/del_version", repo.DelTrainJobVersion)
m.Post("/stop_version", repo.StopTrainJobVersion) m.Post("/stop_version", repo.StopTrainJobVersion)
m.Get("/model_list", repo.ModelList) m.Get("/model_list", repo.ModelList)
m.Get("/model_download", repo.ModelDownload)
}) })
}) })
}, reqRepoReader(models.UnitTypeCloudBrain)) }, reqRepoReader(models.UnitTypeCloudBrain))


+ 0
- 61
routers/api/v1/repo/modelarts.go View File

@@ -7,7 +7,6 @@ package repo


import ( import (
"net/http" "net/http"
"path"
"strconv" "strconv"
"strings" "strings"


@@ -15,7 +14,6 @@ import (
"code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/modelarts" "code.gitea.io/gitea/modules/modelarts"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage" "code.gitea.io/gitea/modules/storage"
) )


@@ -321,62 +319,3 @@ func ModelList(ctx *context.APIContext) {
"PageIsCloudBrain": true, "PageIsCloudBrain": true,
}) })
} }

func ModelDownload(ctx *context.Context) {
var (
err error
)

var jobID = ctx.Params(":jobid")
versionName := ctx.Query("version_name")
// versionName := "V0001"
parentDir := ctx.Query("parent_dir")
fileName := ctx.Query("file_name")
log.Info("DownloadSingleModelFile start.")
// id := ctx.Params(":ID")
// path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName
task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
if err != nil {
log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error())
return
}

path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/")
log.Info("Download path is:%s", path)
if setting.PROXYURL == "" {
body, err := storage.ObsDownloadAFile(setting.Bucket, path)
if err != nil {
log.Info("download error.")
} else {
//count++
// models.ModifyModelDownloadCount(id)
defer body.Close()
ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName)
ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
p := make([]byte, 1024)
var readErr error
var readCount int
// 读取对象内容
for {
readCount, readErr = body.Read(p)
if readCount > 0 {
ctx.Resp.Write(p[:readCount])
//fmt.Printf("%s", p[:readCount])
}
if readErr != nil {
break
}
}
}
} else {
url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path)
if err != nil {
log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
ctx.ServerError("GetObsCreateSignedUrl", err)
return
}
//count++
// models.ModifyModelDownloadCount(id)
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}
}

+ 20
- 13
routers/repo/modelarts.go View File

@@ -1160,6 +1160,14 @@ func TrainJobShow(ctx *context.Context) {
ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil) ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil)
return return
} }
//设置权限
canNewJob, err := canUserCreateTrainJobVersion(ctx, VersionListTasks[0].UserID)
if err != nil {
ctx.ServerError("canNewJob failed", err)
return
}
ctx.Data["canNewJob"] = canNewJob

//将运行参数转化为epoch_size = 3, device_target = Ascend的格式 //将运行参数转化为epoch_size = 3, device_target = Ascend的格式
for i, _ := range VersionListTasks { for i, _ := range VersionListTasks {


@@ -1311,19 +1319,18 @@ func canUserCreateTrainJob(uid int64) (bool, error) {


return org.IsOrgMember(uid) return org.IsOrgMember(uid)
} }
func canUserCreateTrainJobVersion(ctx *context.Context, jobID string, versionName string) (bool, error) {
// task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName)
// if err != nil {
// return false, err
// }
// if ctx.User.ID == task.UserID {
// canNewJob := true
// return canNewJob, nil
// } else {
// canNewJob := false
// return canNewJob, nil
// }
return true, nil
func canUserCreateTrainJobVersion(ctx *context.Context, userID int64) (bool, error) {
if ctx == nil || ctx.User == nil {
log.Error("user unlogin!")
return false, nil
}
if userID == ctx.User.ID || ctx.User.IsAdmin {
return true, nil
} else {
log.Error("Only user itself and admin can new trainjob!")
// ctx.ServerError("Only user itself and admin can new trainjob!", nil)
return false, nil
}
} }


func TrainJobGetConfigList(ctx *context.Context) { func TrainJobGetConfigList(ctx *context.Context) {


+ 8
- 8
templates/repo/modelarts/trainjob/show.tmpl View File

@@ -189,7 +189,7 @@ td, th {
<div style="float: right;"> <div style="float: right;">
<!-- <a class="ti-action-menu-item {{if ne .Status "COMPLETED"}}disabled {{end}}">创建模型</a> --> <!-- <a class="ti-action-menu-item {{if ne .Status "COMPLETED"}}disabled {{end}}">创建模型</a> -->
{{$.CsrfTokenHtml}} {{$.CsrfTokenHtml}}
{{if $.Permission.CanWrite $.UnitTypeCloudBrain}}
{{if $.canNewJob}}
<a class="ti-action-menu-item" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a> <a class="ti-action-menu-item" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a>
{{else}} {{else}}
<a class="ti-action-menu-item disabled" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a> <a class="ti-action-menu-item disabled" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a>
@@ -708,14 +708,14 @@ td, th {
html += "</div>" html += "</div>"
$(`#dir_list${version_name}`).append(html) $(`#dir_list${version_name}`).append(html)
} }
// $(`.log{}`).scroll()
function logScroll(version_name) { function logScroll(version_name) {
let scrollTop = $(`#log${version_name}`)[0].scrollTop; // 滚动距离
let scrollHeight = $(`#log${version_name}`)[0].scrollHeight; // 文档高度
let divHeight = $(`#log${version_name}`).height(); // 可视区高度
// let version_name=$(this).find('input[name=version_name]').val()
console.log("scrollTo,scrollHeight,divHeight",scrollTop,scrollHeight,divHeight)
if(parseInt(scrollTop) + divHeight -10 == scrollHeight){
let container = document.querySelector(`#log${version_name}`)
let scrollTop = container.scrollTop
let scrollHeight = container.scrollHeight
let clientHeight = container.clientHeight
if(parseInt(scrollTop) + clientHeight == scrollHeight && scrollHeight>clientHeight){
let end_line = $(`#log${version_name} input[name=end_line]`).val() let end_line = $(`#log${version_name} input[name=end_line]`).val()
$.get(`/api/v1/repos/${userName}/${repoPath}/modelarts/train-job/${jobID}/log?version_name=${version_name}&base_line=${end_line}&lines=50&order=desc`, (data) => { $.get(`/api/v1/repos/${userName}/${repoPath}/modelarts/train-job/${jobID}/log?version_name=${version_name}&base_line=${end_line}&lines=50&order=desc`, (data) => {
if (data.Lines == 0){ if (data.Lines == 0){


Loading…
Cancel
Save