Reviewed-on: https://git.openi.org.cn/OpenI/aiforge/pulls/962 Reviewed-by: zhoupzh <zhoupzh@pcl.ac.cn>pull/963/head
@@ -337,20 +337,3 @@ func ObsCreateObject(path string) error { | |||||
return nil | return nil | ||||
} | } | ||||
func ObsDownloadAFile(bucket string, key string) (io.ReadCloser, error) { | |||||
input := &obs.GetObjectInput{} | |||||
input.Bucket = bucket | |||||
input.Key = key | |||||
output, err := ObsCli.GetObject(input) | |||||
if err == nil { | |||||
log.Info("StorageClass:%s, ETag:%s, ContentType:%s, ContentLength:%d, LastModified:%s\n", | |||||
output.StorageClass, output.ETag, output.ContentType, output.ContentLength, output.LastModified) | |||||
return output.Body, nil | |||||
} else if obsError, ok := err.(obs.ObsError); ok { | |||||
log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message) | |||||
return nil, obsError | |||||
} else { | |||||
return nil, err | |||||
} | |||||
} |
@@ -880,7 +880,6 @@ func RegisterRoutes(m *macaron.Macaron) { | |||||
m.Post("/del_version", repo.DelTrainJobVersion) | m.Post("/del_version", repo.DelTrainJobVersion) | ||||
m.Post("/stop_version", repo.StopTrainJobVersion) | m.Post("/stop_version", repo.StopTrainJobVersion) | ||||
m.Get("/model_list", repo.ModelList) | m.Get("/model_list", repo.ModelList) | ||||
m.Get("/model_download", repo.ModelDownload) | |||||
}) | }) | ||||
}) | }) | ||||
}, reqRepoReader(models.UnitTypeCloudBrain)) | }, reqRepoReader(models.UnitTypeCloudBrain)) | ||||
@@ -7,7 +7,6 @@ package repo | |||||
import ( | import ( | ||||
"net/http" | "net/http" | ||||
"path" | |||||
"strconv" | "strconv" | ||||
"strings" | "strings" | ||||
@@ -15,7 +14,6 @@ import ( | |||||
"code.gitea.io/gitea/modules/context" | "code.gitea.io/gitea/modules/context" | ||||
"code.gitea.io/gitea/modules/log" | "code.gitea.io/gitea/modules/log" | ||||
"code.gitea.io/gitea/modules/modelarts" | "code.gitea.io/gitea/modules/modelarts" | ||||
"code.gitea.io/gitea/modules/setting" | |||||
"code.gitea.io/gitea/modules/storage" | "code.gitea.io/gitea/modules/storage" | ||||
) | ) | ||||
@@ -321,62 +319,3 @@ func ModelList(ctx *context.APIContext) { | |||||
"PageIsCloudBrain": true, | "PageIsCloudBrain": true, | ||||
}) | }) | ||||
} | } | ||||
func ModelDownload(ctx *context.Context) { | |||||
var ( | |||||
err error | |||||
) | |||||
var jobID = ctx.Params(":jobid") | |||||
versionName := ctx.Query("version_name") | |||||
// versionName := "V0001" | |||||
parentDir := ctx.Query("parent_dir") | |||||
fileName := ctx.Query("file_name") | |||||
log.Info("DownloadSingleModelFile start.") | |||||
// id := ctx.Params(":ID") | |||||
// path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName | |||||
task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) | |||||
if err != nil { | |||||
log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) | |||||
return | |||||
} | |||||
path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/") | |||||
log.Info("Download path is:%s", path) | |||||
if setting.PROXYURL == "" { | |||||
body, err := storage.ObsDownloadAFile(setting.Bucket, path) | |||||
if err != nil { | |||||
log.Info("download error.") | |||||
} else { | |||||
//count++ | |||||
// models.ModifyModelDownloadCount(id) | |||||
defer body.Close() | |||||
ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName) | |||||
ctx.Resp.Header().Set("Content-Type", "application/octet-stream") | |||||
p := make([]byte, 1024) | |||||
var readErr error | |||||
var readCount int | |||||
// 读取对象内容 | |||||
for { | |||||
readCount, readErr = body.Read(p) | |||||
if readCount > 0 { | |||||
ctx.Resp.Write(p[:readCount]) | |||||
//fmt.Printf("%s", p[:readCount]) | |||||
} | |||||
if readErr != nil { | |||||
break | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path) | |||||
if err != nil { | |||||
log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"]) | |||||
ctx.ServerError("GetObsCreateSignedUrl", err) | |||||
return | |||||
} | |||||
//count++ | |||||
// models.ModifyModelDownloadCount(id) | |||||
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) | |||||
} | |||||
} |
@@ -1160,6 +1160,14 @@ func TrainJobShow(ctx *context.Context) { | |||||
ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil) | ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil) | ||||
return | return | ||||
} | } | ||||
//设置权限 | |||||
canNewJob, err := canUserCreateTrainJobVersion(ctx, VersionListTasks[0].UserID) | |||||
if err != nil { | |||||
ctx.ServerError("canNewJob failed", err) | |||||
return | |||||
} | |||||
ctx.Data["canNewJob"] = canNewJob | |||||
//将运行参数转化为epoch_size = 3, device_target = Ascend的格式 | //将运行参数转化为epoch_size = 3, device_target = Ascend的格式 | ||||
for i, _ := range VersionListTasks { | for i, _ := range VersionListTasks { | ||||
@@ -1311,19 +1319,18 @@ func canUserCreateTrainJob(uid int64) (bool, error) { | |||||
return org.IsOrgMember(uid) | return org.IsOrgMember(uid) | ||||
} | } | ||||
func canUserCreateTrainJobVersion(ctx *context.Context, jobID string, versionName string) (bool, error) { | |||||
// task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) | |||||
// if err != nil { | |||||
// return false, err | |||||
// } | |||||
// if ctx.User.ID == task.UserID { | |||||
// canNewJob := true | |||||
// return canNewJob, nil | |||||
// } else { | |||||
// canNewJob := false | |||||
// return canNewJob, nil | |||||
// } | |||||
return true, nil | |||||
func canUserCreateTrainJobVersion(ctx *context.Context, userID int64) (bool, error) { | |||||
if ctx == nil || ctx.User == nil { | |||||
log.Error("user unlogin!") | |||||
return false, nil | |||||
} | |||||
if userID == ctx.User.ID || ctx.User.IsAdmin { | |||||
return true, nil | |||||
} else { | |||||
log.Error("Only user itself and admin can new trainjob!") | |||||
// ctx.ServerError("Only user itself and admin can new trainjob!", nil) | |||||
return false, nil | |||||
} | |||||
} | } | ||||
func TrainJobGetConfigList(ctx *context.Context) { | func TrainJobGetConfigList(ctx *context.Context) { | ||||
@@ -189,7 +189,7 @@ td, th { | |||||
<div style="float: right;"> | <div style="float: right;"> | ||||
<!-- <a class="ti-action-menu-item {{if ne .Status "COMPLETED"}}disabled {{end}}">创建模型</a> --> | <!-- <a class="ti-action-menu-item {{if ne .Status "COMPLETED"}}disabled {{end}}">创建模型</a> --> | ||||
{{$.CsrfTokenHtml}} | {{$.CsrfTokenHtml}} | ||||
{{if $.Permission.CanWrite $.UnitTypeCloudBrain}} | |||||
{{if $.canNewJob}} | |||||
<a class="ti-action-menu-item" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a> | <a class="ti-action-menu-item" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a> | ||||
{{else}} | {{else}} | ||||
<a class="ti-action-menu-item disabled" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a> | <a class="ti-action-menu-item disabled" href="{{$.RepoLink}}/modelarts/train-job/{{.JobID}}/create_version?version_name={{.VersionName}}">{{$.i18n.Tr "repo.modelarts.modify"}}</a> | ||||
@@ -708,14 +708,14 @@ td, th { | |||||
html += "</div>" | html += "</div>" | ||||
$(`#dir_list${version_name}`).append(html) | $(`#dir_list${version_name}`).append(html) | ||||
} | } | ||||
// $(`.log{}`).scroll() | |||||
function logScroll(version_name) { | function logScroll(version_name) { | ||||
let scrollTop = $(`#log${version_name}`)[0].scrollTop; // 滚动距离 | |||||
let scrollHeight = $(`#log${version_name}`)[0].scrollHeight; // 文档高度 | |||||
let divHeight = $(`#log${version_name}`).height(); // 可视区高度 | |||||
// let version_name=$(this).find('input[name=version_name]').val() | |||||
console.log("scrollTo,scrollHeight,divHeight",scrollTop,scrollHeight,divHeight) | |||||
if(parseInt(scrollTop) + divHeight -10 == scrollHeight){ | |||||
let container = document.querySelector(`#log${version_name}`) | |||||
let scrollTop = container.scrollTop | |||||
let scrollHeight = container.scrollHeight | |||||
let clientHeight = container.clientHeight | |||||
if(parseInt(scrollTop) + clientHeight == scrollHeight && scrollHeight>clientHeight){ | |||||
let end_line = $(`#log${version_name} input[name=end_line]`).val() | let end_line = $(`#log${version_name} input[name=end_line]`).val() | ||||
$.get(`/api/v1/repos/${userName}/${repoPath}/modelarts/train-job/${jobID}/log?version_name=${version_name}&base_line=${end_line}&lines=50&order=desc`, (data) => { | $.get(`/api/v1/repos/${userName}/${repoPath}/modelarts/train-job/${jobID}/log?version_name=${version_name}&base_line=${end_line}&lines=50&order=desc`, (data) => { | ||||
if (data.Lines == 0){ | if (data.Lines == 0){ | ||||