diff --git a/modules/storage/obs.go b/modules/storage/obs.go index 540463e80..a5c463bb0 100755 --- a/modules/storage/obs.go +++ b/modules/storage/obs.go @@ -337,20 +337,3 @@ func ObsCreateObject(path string) error { return nil } - -func ObsDownloadAFile(bucket string, key string) (io.ReadCloser, error) { - input := &obs.GetObjectInput{} - input.Bucket = bucket - input.Key = key - output, err := ObsCli.GetObject(input) - if err == nil { - log.Info("StorageClass:%s, ETag:%s, ContentType:%s, ContentLength:%d, LastModified:%s\n", - output.StorageClass, output.ETag, output.ContentType, output.ContentLength, output.LastModified) - return output.Body, nil - } else if obsError, ok := err.(obs.ObsError); ok { - log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message) - return nil, obsError - } else { - return nil, err - } -} diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 93afe2a31..b7ef8d48f 100755 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -880,7 +880,6 @@ func RegisterRoutes(m *macaron.Macaron) { m.Post("/del_version", repo.DelTrainJobVersion) m.Post("/stop_version", repo.StopTrainJobVersion) m.Get("/model_list", repo.ModelList) - m.Get("/model_download", repo.ModelDownload) }) }) }, reqRepoReader(models.UnitTypeCloudBrain)) diff --git a/routers/api/v1/repo/modelarts.go b/routers/api/v1/repo/modelarts.go index 1e6b17ad8..b4d5fc010 100755 --- a/routers/api/v1/repo/modelarts.go +++ b/routers/api/v1/repo/modelarts.go @@ -7,7 +7,6 @@ package repo import ( "net/http" - "path" "strconv" "strings" @@ -15,7 +14,6 @@ import ( "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/modelarts" - "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/storage" ) @@ -321,62 +319,3 @@ func ModelList(ctx *context.APIContext) { "PageIsCloudBrain": true, }) } - -func ModelDownload(ctx *context.Context) { - var ( - err error - ) - - var jobID = ctx.Params(":jobid") - versionName := ctx.Query("version_name") - // versionName := "V0001" - parentDir := ctx.Query("parent_dir") - fileName := ctx.Query("file_name") - log.Info("DownloadSingleModelFile start.") - // id := ctx.Params(":ID") - // path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName - task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) - if err != nil { - log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) - return - } - - path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/") - log.Info("Download path is:%s", path) - if setting.PROXYURL == "" { - body, err := storage.ObsDownloadAFile(setting.Bucket, path) - if err != nil { - log.Info("download error.") - } else { - //count++ - // models.ModifyModelDownloadCount(id) - defer body.Close() - ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName) - ctx.Resp.Header().Set("Content-Type", "application/octet-stream") - p := make([]byte, 1024) - var readErr error - var readCount int - // 读取对象内容 - for { - readCount, readErr = body.Read(p) - if readCount > 0 { - ctx.Resp.Write(p[:readCount]) - //fmt.Printf("%s", p[:readCount]) - } - if readErr != nil { - break - } - } - } - } else { - url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path) - if err != nil { - log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"]) - ctx.ServerError("GetObsCreateSignedUrl", err) - return - } - //count++ - // models.ModifyModelDownloadCount(id) - http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) - } -} diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index a820c56fa..e048b98f4 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -1160,6 +1160,14 @@ func TrainJobShow(ctx *context.Context) { ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobShow, nil) return } + //设置权限 + canNewJob, err := canUserCreateTrainJobVersion(ctx, VersionListTasks[0].UserID) + if err != nil { + ctx.ServerError("canNewJob failed", err) + return + } + ctx.Data["canNewJob"] = canNewJob + //将运行参数转化为epoch_size = 3, device_target = Ascend的格式 for i, _ := range VersionListTasks { @@ -1311,19 +1319,18 @@ func canUserCreateTrainJob(uid int64) (bool, error) { return org.IsOrgMember(uid) } -func canUserCreateTrainJobVersion(ctx *context.Context, jobID string, versionName string) (bool, error) { - // task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) - // if err != nil { - // return false, err - // } - // if ctx.User.ID == task.UserID { - // canNewJob := true - // return canNewJob, nil - // } else { - // canNewJob := false - // return canNewJob, nil - // } - return true, nil +func canUserCreateTrainJobVersion(ctx *context.Context, userID int64) (bool, error) { + if ctx == nil || ctx.User == nil { + log.Error("user unlogin!") + return false, nil + } + if userID == ctx.User.ID || ctx.User.IsAdmin { + return true, nil + } else { + log.Error("Only user itself and admin can new trainjob!") + // ctx.ServerError("Only user itself and admin can new trainjob!", nil) + return false, nil + } } func TrainJobGetConfigList(ctx *context.Context) { diff --git a/templates/repo/modelarts/trainjob/show.tmpl b/templates/repo/modelarts/trainjob/show.tmpl index 7d32d1b9c..4c9f38bb0 100755 --- a/templates/repo/modelarts/trainjob/show.tmpl +++ b/templates/repo/modelarts/trainjob/show.tmpl @@ -189,7 +189,7 @@ td, th {