diff --git a/modules/grampus/grampus.go b/modules/grampus/grampus.go index 5f3bd07dd..c55fd58ea 100755 --- a/modules/grampus/grampus.go +++ b/modules/grampus/grampus.go @@ -109,6 +109,7 @@ type GenerateNotebookJobReq struct { Spec *models.Specification CodeName string ModelPath string //参考启智GPU调试, 挂载/model目录用户的模型可以输出到这个目录 + ModelStorageType int } func getEndPoint() string { @@ -215,14 +216,25 @@ func GenerateNotebookJob(ctx *context.Context, req *GenerateNotebookJobReq) (job datasetGrampus, cpCommand = getDatasetGPUGrampus(req.DatasetInfos) } if len(req.ModelName) != 0 { - datasetGrampus = append(datasetGrampus, models.GrampusDataset{ - Name: req.ModelName, - Bucket: setting.Attachment.Minio.Bucket, - EndPoint: setting.Attachment.Minio.Endpoint, - ObjectKey: req.PreTrainModelPath, - ReadOnly: true, - ContainerPath: cloudbrain.PretrainModelMountPath, - }) + if req.ModelStorageType == models.TypeCloudBrainOne { + datasetGrampus = append(datasetGrampus, models.GrampusDataset{ + Name: req.ModelName, + Bucket: setting.Attachment.Minio.Bucket, + EndPoint: setting.Attachment.Minio.Endpoint, + ObjectKey: req.PreTrainModelPath, + ReadOnly: true, + ContainerPath: cloudbrain.PretrainModelMountPath, + }) + } else { + datasetGrampus = append(datasetGrampus, models.GrampusDataset{ + Name: req.ModelName, + Bucket: setting.Bucket, + EndPoint: getEndPoint(), + ReadOnly: true, + ObjectKey: req.PreTrainModelPath, + }) + } + } codeArchiveName := cloudbrain.DefaultBranchName + ".zip" codeGrampus = models.GrampusDataset{ diff --git a/routers/repo/grampus.go b/routers/repo/grampus.go index 673960fa1..5264f3a43 100755 --- a/routers/repo/grampus.go +++ b/routers/repo/grampus.go @@ -267,7 +267,7 @@ func GrampusNotebookCreate(ctx *context.Context, form auth.CreateGrampusNotebook if form.ModelName != "" { //使用预训练模型训练 - _, err := models.QueryModelByPath(form.PreTrainModelUrl) + m, err := models.QueryModelByPath(form.PreTrainModelUrl) if err != nil { log.Error("Can not find model", err) grampusNotebookNewDataPrepare(ctx, processType) @@ -280,6 +280,7 @@ func GrampusNotebookCreate(ctx *context.Context, form auth.CreateGrampusNotebook req.ModelVersion = form.ModelVersion req.PreTrainModelUrl = form.PreTrainModelUrl req.PreTrainModelPath = getPreTrainModelPath(form.PreTrainModelUrl, form.CkptName) + req.ModelStorageType = m.Type } _, err = grampus.GenerateNotebookJob(ctx, req)