From 74a95e054ac40cb18f1f6bce7d6656646052f81c Mon Sep 17 00:00:00 2001 From: liuzx Date: Thu, 25 Nov 2021 11:06:26 +0800 Subject: [PATCH] fix bug --- routers/api/v1/repo/modelarts.go | 2 +- routers/repo/modelarts.go | 87 ++++++++++++++++++++++++++++++++++------ routers/routes/routes.go | 1 + 3 files changed, 76 insertions(+), 14 deletions(-) diff --git a/routers/api/v1/repo/modelarts.go b/routers/api/v1/repo/modelarts.go index 80ff57a56..1e6b17ad8 100755 --- a/routers/api/v1/repo/modelarts.go +++ b/routers/api/v1/repo/modelarts.go @@ -343,7 +343,7 @@ func ModelDownload(ctx *context.Context) { path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/") log.Info("Download path is:%s", path) - if setting.PROXYURL != "" { + if setting.PROXYURL == "" { body, err := storage.ObsDownloadAFile(setting.Bucket, path) if err != nil { log.Info("download error.") diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index a3602e0e7..5e5b0a3e0 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -480,9 +480,9 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { var jobID = ctx.Params(":jobid") // var versionName = ctx.Params(":version-name") var versionName = ctx.Query("version_name") - // canNewJob, err := canUserCreateTrainJobVersion(ctx, jobID) + // canNewJob, err := canUserCreateTrainJobVersion(ctx, jobID, versionName) // if err != nil { - // ctx.ServerError("get can info failed", err) + // ctx.ServerError("canNewJob can info failed", err) // return err // } // ctx.Data["canNewJob"] = canNewJob @@ -1312,17 +1312,19 @@ func canUserCreateTrainJob(uid int64) (bool, error) { return org.IsOrgMember(uid) } -func canUserCreateTrainJobVersion(ctx *context.Context, jobID string) (bool, error) { - - var versionName = "V0001" - task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) - if err != nil { - return false, err - } - if ctx.User.ID == task.User.ID { - return true, nil - } - return false, err +func canUserCreateTrainJobVersion(ctx *context.Context, jobID string, versionName string) (bool, error) { + // task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) + // if err != nil { + // return false, err + // } + // if ctx.User.ID == task.UserID { + // canNewJob := true + // return canNewJob, nil + // } else { + // canNewJob := false + // return canNewJob, nil + // } + return true, nil } func TrainJobGetConfigList(ctx *context.Context) { @@ -1378,3 +1380,62 @@ func getConfigList(perPage, page int, sortBy, order, searchContent, configType s return list, nil } + +func ModelDownload(ctx *context.Context) { + var ( + err error + ) + + var jobID = ctx.Params(":jobid") + versionName := ctx.Query("version_name") + // versionName := "V0001" + parentDir := ctx.Query("parent_dir") + fileName := ctx.Query("file_name") + log.Info("DownloadSingleModelFile start.") + // id := ctx.Params(":ID") + // path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName + task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) + if err != nil { + log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) + return + } + + path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/") + log.Info("Download path is:%s", path) + if setting.PROXYURL == "" { + body, err := storage.ObsDownloadAFile(setting.Bucket, path) + if err != nil { + log.Info("download error.") + } else { + //count++ + // models.ModifyModelDownloadCount(id) + defer body.Close() + ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName) + ctx.Resp.Header().Set("Content-Type", "application/octet-stream") + p := make([]byte, 1024) + var readErr error + var readCount int + // 读取对象内容 + for { + readCount, readErr = body.Read(p) + if readCount > 0 { + ctx.Resp.Write(p[:readCount]) + //fmt.Printf("%s", p[:readCount]) + } + if readErr != nil { + break + } + } + } + } else { + url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path) + if err != nil { + log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"]) + ctx.ServerError("GetObsCreateSignedUrl", err) + return + } + //count++ + // models.ModifyModelDownloadCount(id) + http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) + } +} diff --git a/routers/routes/routes.go b/routers/routes/routes.go index d3212690c..3bacb7549 100755 --- a/routers/routes/routes.go +++ b/routers/routes/routes.go @@ -988,6 +988,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("", reqRepoCloudBrainReader, repo.TrainJobShow) m.Post("/stop", reqRepoCloudBrainWriter, repo.TrainJobStop) m.Post("/del", reqRepoCloudBrainWriter, repo.TrainJobDel) + m.Get("/model_download", reqRepoCloudBrainReader, repo.ModelDownload) m.Get("/create_version", reqRepoCloudBrainReader, repo.TrainJobNewVersion) m.Post("/create_version", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreateVersion) })