diff --git a/modules/storage/obs.go b/modules/storage/obs.go index 9ca26b357..540463e80 100755 --- a/modules/storage/obs.go +++ b/modules/storage/obs.go @@ -176,10 +176,10 @@ func ObsModelDownload(JobName string, fileName string) (io.ReadCloser, error) { } } -func GetObsListObject(jobName, parentDir string, versionOutputPath string) ([]FileInfo, error) { +func GetObsListObject(jobName, parentDir, versionName string) ([]FileInfo, error) { input := &obs.ListObjectsInput{} input.Bucket = setting.Bucket - input.Prefix = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, versionOutputPath, parentDir), "/") + input.Prefix = strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, versionName, parentDir), "/") strPrefix := strings.Split(input.Prefix, "/") output, err := ObsCli.ListObjects(input) fileInfos := make([]FileInfo, 0) @@ -275,8 +275,34 @@ func GetObsCreateSignedUrl(jobName, parentDir, fileName string) (string, error) log.Error("CreateSignedUrl failed:", err.Error()) return "", err } + log.Info("SignedUrl:%s", output.SignedUrl) + return output.SignedUrl, nil +} + +func GetObsCreateSignedUrlByBucketAndKey(bucket, key string) (string, error) { + input := &obs.CreateSignedUrlInput{} + input.Bucket = bucket + input.Key = key + + input.Expires = 60 * 60 + input.Method = obs.HttpMethodGet + comma := strings.LastIndex(key, "/") + filename := key + if comma != -1 { + filename = key[comma+1:] + } + reqParams := make(map[string]string) + filename = url.QueryEscape(filename) + reqParams["response-content-disposition"] = "attachment; filename=\"" + filename + "\"" + input.QueryParams = reqParams + output, err := ObsCli.CreateSignedUrl(input) + if err != nil { + log.Error("CreateSignedUrl failed:", err.Error()) + return "", err + } return output.SignedUrl, nil + } func ObsGetPreSignedUrl(uuid, fileName string) (string, error) { @@ -311,3 +337,20 @@ func ObsCreateObject(path string) error { return nil } + +func ObsDownloadAFile(bucket string, key string) (io.ReadCloser, error) { + input := &obs.GetObjectInput{} + input.Bucket = bucket + input.Key = key + output, err := ObsCli.GetObject(input) + if err == nil { + log.Info("StorageClass:%s, ETag:%s, ContentType:%s, ContentLength:%d, LastModified:%s\n", + output.StorageClass, output.ETag, output.ContentType, output.ContentLength, output.LastModified) + return output.Body, nil + } else if obsError, ok := err.(obs.ObsError); ok { + log.Error("Code:%s, Message:%s", obsError.Code, obsError.Message) + return nil, obsError + } else { + return nil, err + } +} diff --git a/routers/api/v1/repo/modelarts.go b/routers/api/v1/repo/modelarts.go index 9637bae99..80ff57a56 100755 --- a/routers/api/v1/repo/modelarts.go +++ b/routers/api/v1/repo/modelarts.go @@ -7,6 +7,7 @@ package repo import ( "net/http" + "path" "strconv" "strings" @@ -14,6 +15,7 @@ import ( "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/modelarts" + "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/storage" ) @@ -302,8 +304,7 @@ func ModelList(ctx *context.APIContext) { log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) return } - VersionOutputPath := modelarts.GetVersionOutputPathByTotalVersionCount(task.TotalVersionCount) - models, err := storage.GetObsListObject(task.JobName, parentDir, VersionOutputPath) + models, err := storage.GetObsListObject(task.JobName, parentDir, versionName) if err != nil { log.Info("get TrainJobListModel failed:", err) ctx.ServerError("GetObsListObject:", err) @@ -321,27 +322,61 @@ func ModelList(ctx *context.APIContext) { }) } -func ModelDownload(ctx *context.APIContext) { +func ModelDownload(ctx *context.Context) { var ( err error ) var jobID = ctx.Params(":jobid") versionName := ctx.Query("version_name") + // versionName := "V0001" parentDir := ctx.Query("parent_dir") fileName := ctx.Query("file_name") + log.Info("DownloadSingleModelFile start.") + // id := ctx.Params(":ID") + // path := Model_prefix + models.AttachmentRelativePath(id) + "/" + parentDir + fileName task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) if err != nil { log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) return } - VersionOutputPath := modelarts.GetVersionOutputPathByTotalVersionCount(task.TotalVersionCount) - parentDir = VersionOutputPath + "/" + parentDir - url, err := storage.GetObsCreateSignedUrl(task.JobName, parentDir, fileName) - if err != nil { - log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"]) - ctx.ServerError("GetObsCreateSignedUrl", err) - return + + path := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, task.JobName, setting.OutPutPath, versionName, parentDir, fileName), "/") + log.Info("Download path is:%s", path) + if setting.PROXYURL != "" { + body, err := storage.ObsDownloadAFile(setting.Bucket, path) + if err != nil { + log.Info("download error.") + } else { + //count++ + // models.ModifyModelDownloadCount(id) + defer body.Close() + ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName) + ctx.Resp.Header().Set("Content-Type", "application/octet-stream") + p := make([]byte, 1024) + var readErr error + var readCount int + // 读取对象内容 + for { + readCount, readErr = body.Read(p) + if readCount > 0 { + ctx.Resp.Write(p[:readCount]) + //fmt.Printf("%s", p[:readCount]) + } + if readErr != nil { + break + } + } + } + } else { + url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path) + if err != nil { + log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"]) + ctx.ServerError("GetObsCreateSignedUrl", err) + return + } + //count++ + // models.ModifyModelDownloadCount(id) + http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) } - http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) } diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index e35cf5c6b..a3602e0e7 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -369,6 +369,7 @@ func trainJobNewDataPrepare(ctx *context.Context) error { } ctx.Data["Branches"] = Branches ctx.Data["BranchesCount"] = len(Branches) + ctx.Data["params"] = "" configList, err := getConfigList(modelarts.PerPage, 1, modelarts.SortByCreateTime, "desc", "", modelarts.ConfigTypeCustom) if err != nil { @@ -380,7 +381,92 @@ func trainJobNewDataPrepare(ctx *context.Context) error { return nil } +func ErrorNewDataPrepare(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) error { + ctx.Data["PageIsCloudBrain"] = true + + //can, err := canUserCreateTrainJob(ctx.User.ID) + //if err != nil { + // ctx.ServerError("canUserCreateTrainJob", err) + // return + //} + // + //if !can { + // log.Error("the user can not create train-job") + // ctx.ServerError("the user can not create train-job", fmt.Errorf("the user can not create train-job")) + // return + //} + + t := time.Now() + var jobName = cutString(ctx.User.Name, 5) + t.Format("2006010215") + strconv.Itoa(int(t.Unix()))[5:] + ctx.Data["job_name"] = jobName + + attachs, err := models.GetModelArtsUserAttachments(ctx.User.ID) + if err != nil { + ctx.ServerError("GetAllUserAttachments failed:", err) + return err + } + ctx.Data["attachments"] = attachs + + var resourcePools modelarts.ResourcePool + if err = json.Unmarshal([]byte(setting.ResourcePools), &resourcePools); err != nil { + ctx.ServerError("json.Unmarshal failed:", err) + return err + } + ctx.Data["resource_pools"] = resourcePools.Info + + var engines modelarts.Engine + if err = json.Unmarshal([]byte(setting.Engines), &engines); err != nil { + ctx.ServerError("json.Unmarshal failed:", err) + return err + } + ctx.Data["engines"] = engines.Info + + var versionInfos modelarts.VersionInfo + if err = json.Unmarshal([]byte(setting.EngineVersions), &versionInfos); err != nil { + ctx.ServerError("json.Unmarshal failed:", err) + return err + } + ctx.Data["engine_versions"] = versionInfos.Version + + var flavorInfos modelarts.Flavor + if err = json.Unmarshal([]byte(setting.TrainJobFLAVORINFOS), &flavorInfos); err != nil { + ctx.ServerError("json.Unmarshal failed:", err) + return err + } + ctx.Data["flavor_infos"] = flavorInfos.Info + + outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + ctx.Data["train_url"] = outputObsPath + + Branches, err := ctx.Repo.GitRepo.GetBranches() + if err != nil { + ctx.ServerError("GetBranches error:", err) + return err + } + ctx.Data["Branches"] = Branches + ctx.Data["BranchesCount"] = len(Branches) + + configList, err := getConfigList(modelarts.PerPage, 1, modelarts.SortByCreateTime, "desc", "", modelarts.ConfigTypeCustom) + if err != nil { + ctx.ServerError("getConfigList failed:", err) + return err + } + var Parameters modelarts.Parameters + if err = json.Unmarshal([]byte(form.Params), &Parameters); err != nil { + ctx.ServerError("json.Unmarshal failed:", err) + return err + } + ctx.Data["params"] = Parameters.Parameter + ctx.Data["config_list"] = configList.ParaConfigs + ctx.Data["bootFile"] = form.BootFile + ctx.Data["uuid"] = form.Attachment + ctx.Data["branch_name"] = form.BranchName + + return nil +} + func TrainJobNewVersion(ctx *context.Context) { + err := trainJobNewVersionDataPrepare(ctx) if err != nil { ctx.ServerError("get new train-job info failed", err) @@ -394,6 +480,12 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { var jobID = ctx.Params(":jobid") // var versionName = ctx.Params(":version-name") var versionName = ctx.Query("version_name") + // canNewJob, err := canUserCreateTrainJobVersion(ctx, jobID) + // if err != nil { + // ctx.ServerError("get can info failed", err) + // return err + // } + // ctx.Data["canNewJob"] = canNewJob task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) if err != nil { @@ -455,7 +547,8 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { ctx.ServerError("GetBranches error:", err) return err } - ctx.Data["branches"] = Branches + + ctx.Data["branch"] = Branches ctx.Data["branch_name"] = task.BranchName ctx.Data["description"] = task.Description ctx.Data["boot_file"] = task.BootFile @@ -477,7 +570,7 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { return nil } -func ErrorDataPrepare(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) error { +func VersionErrorDataPrepare(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) error { ctx.Data["PageIsCloudBrain"] = true var jobID = ctx.Params(":jobid") // var versionName = ctx.Params(":version-name") @@ -580,7 +673,6 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) isSaveParam := form.IsSaveParam repo := ctx.Repo.Repository codeLocalPath := setting.JobPath + jobName + modelarts.CodePath - // codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/" codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/" logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/" @@ -593,7 +685,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) if err := paramCheckCreateTrainJob(form); err != nil { log.Error("paramCheckCreateTrainJob failed:(%v)", err) - trainJobNewDataPrepare(ctx) + ErrorNewDataPrepare(ctx, form) ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobNew, &form) return } @@ -792,7 +884,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ if err := paramCheckCreateTrainJob(form); err != nil { log.Error("paramCheckCreateTrainJob failed:(%v)", err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form) return } @@ -815,7 +907,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ Branch: branch_name, }); err != nil { log.Error("创建任务失败,任务名称已存在!: %s (%v)", repo.FullName(), err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("创建任务失败,任务名称已存在!", tplModelArtsTrainJobVersionNew, &form) return } @@ -823,14 +915,14 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ //todo: upload code (send to file_server todo this work?) if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.OutputPath + VersionOutputPath + "/"); err != nil { log.Error("Failed to obsMkdir_output: %s (%v)", repo.FullName(), err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("Failed to obsMkdir_output", tplModelArtsTrainJobVersionNew, &form) return } if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.LogPath + VersionOutputPath + "/"); err != nil { log.Error("Failed to obsMkdir_log: %s (%v)", repo.FullName(), err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("Failed to obsMkdir_log", tplModelArtsTrainJobVersionNew, &form) return } @@ -839,7 +931,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ // if err := uploadCodeToObs(codeLocalPath, jobName, ""); err != nil { if err := uploadCodeToObs(codeLocalPath, jobName, parentDir); err != nil { log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("Failed to uploadCodeToObs", tplModelArtsTrainJobVersionNew, &form) return } @@ -859,7 +951,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ err := json.Unmarshal([]byte(params), ¶meters) if err != nil { log.Error("Failed to Unmarshal params: %s (%v)", params, err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("运行参数错误", tplModelArtsTrainJobVersionNew, &form) return } @@ -878,7 +970,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ if isSaveParam == "on" { if form.ParameterTemplateName == "" { log.Error("ParameterTemplateName is empty") - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("保存作业参数时,作业参数名称不能为空", tplModelArtsTrainJobVersionNew, &form) return } @@ -902,7 +994,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ if err != nil { log.Error("Failed to CreateTrainJobConfig: %v", err) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr("保存作业参数失败:"+err.Error(), tplModelArtsTrainJobVersionNew, &form) return } @@ -949,7 +1041,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ err = modelarts.GenerateTrainJobVersion(ctx, req, jobID) if err != nil { log.Error("GenerateTrainJob failed:%v", err.Error()) - ErrorDataPrepare(ctx, form) + VersionErrorDataPrepare(ctx, form) ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form) return } @@ -1220,6 +1312,18 @@ func canUserCreateTrainJob(uid int64) (bool, error) { return org.IsOrgMember(uid) } +func canUserCreateTrainJobVersion(ctx *context.Context, jobID string) (bool, error) { + + var versionName = "V0001" + task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) + if err != nil { + return false, err + } + if ctx.User.ID == task.User.ID { + return true, nil + } + return false, err +} func TrainJobGetConfigList(ctx *context.Context) { ctx.Data["PageIsTrainJob"] = true diff --git a/templates/repo/cloudbrain/index.tmpl b/templates/repo/cloudbrain/index.tmpl index d72ffb0c4..097758c8f 100755 --- a/templates/repo/cloudbrain/index.tmpl +++ b/templates/repo/cloudbrain/index.tmpl @@ -307,9 +307,9 @@
@@ -380,7 +380,7 @@ {{end}} -