You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grampus.go 8.7 kB

3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package grampus
  2. import (
  3. "encoding/json"
  4. "strings"
  5. "code.gitea.io/gitea/models"
  6. "code.gitea.io/gitea/modules/cloudbrain"
  7. "code.gitea.io/gitea/modules/context"
  8. "code.gitea.io/gitea/modules/log"
  9. "code.gitea.io/gitea/modules/notification"
  10. "code.gitea.io/gitea/modules/setting"
  11. "code.gitea.io/gitea/modules/timeutil"
  12. )
  13. const (
  14. JobPath = "job/"
  15. ProcessorTypeNPU = "npu.huawei.com/NPU"
  16. ProcessorTypeGPU = "nvidia.com/gpu"
  17. GpuWorkDir = "/tmp/"
  18. NpuWorkDir = "/cache/"
  19. NpuLocalLogUrl = "/tmp/train.log"
  20. CommandPrepareScriptNpu = ";mkdir -p output;mkdir -p code;mkdir -p dataset;mkdir -p pretrainmodel;"
  21. CodeArchiveName = "master.zip"
  22. BucketRemote = "grampus"
  23. RemoteModelPath = "/output/" + models.ModelSuffix
  24. )
  25. var (
  26. poolInfos *models.PoolInfos
  27. FlavorInfos *setting.StFlavorInfos
  28. ImageInfos *setting.StImageInfosModelArts
  29. SpecialPools *models.SpecialPools
  30. CommandPrepareScriptGpu = ";mkdir -p output;mkdir -p code;mkdir -p dataset;mkdir -p pretrainmodel;echo \"start loading script\";wget -q https://openi.pcl.ac.cn/OpenIOSSG/%s/archive/master.zip;" +
  31. "echo \"finish loading script\";unzip -q master.zip;cd %s;chmod 777 downloader_for_obs uploader_for_npu downloader_for_minio uploader_for_gpu;"
  32. )
  33. type GenerateTrainJobReq struct {
  34. JobName string
  35. Command string
  36. ImageUrl string //与image_id二选一,都有的情况下优先image_url
  37. ImageId string
  38. DisplayJobName string
  39. Uuid string
  40. Description string
  41. CodeObsPath string
  42. BootFile string
  43. BootFileUrl string
  44. DataUrl string
  45. TrainUrl string
  46. WorkServerNumber int
  47. EngineID int64
  48. CommitID string
  49. IsLatestVersion string
  50. BranchName string
  51. PreVersionId int64
  52. PreVersionName string
  53. VersionCount int
  54. EngineName string
  55. TotalVersionCount int
  56. ComputeResource string
  57. ProcessType string
  58. DatasetNames string
  59. DatasetInfos map[string]models.DatasetInfo
  60. Params string
  61. ModelName string
  62. LabelName string
  63. CkptName string
  64. ModelVersion string
  65. PreTrainModelPath string
  66. PreTrainModelUrl string
  67. Spec *models.Specification
  68. CodeName string
  69. }
  70. func getEndPoint() string {
  71. index := strings.Index(setting.Endpoint, "//")
  72. endpoint := setting.Endpoint[index+2:]
  73. return endpoint
  74. }
  75. func getDatasetGrampus(datasetInfos map[string]models.DatasetInfo) []models.GrampusDataset {
  76. var datasetGrampus []models.GrampusDataset
  77. endPoint := getEndPoint()
  78. for _, datasetInfo := range datasetInfos {
  79. datasetGrampus = append(datasetGrampus, models.GrampusDataset{
  80. Name: datasetInfo.FullName,
  81. Bucket: setting.Bucket,
  82. EndPoint: endPoint,
  83. ObjectKey: datasetInfo.DataLocalPath + datasetInfo.FullName,
  84. })
  85. }
  86. return datasetGrampus
  87. }
  88. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (jobId string, err error) {
  89. createTime := timeutil.TimeStampNow()
  90. var datasetGrampus, modelGrampus []models.GrampusDataset
  91. var codeGrampus models.GrampusDataset
  92. if ProcessorTypeNPU == req.ProcessType {
  93. datasetGrampus = getDatasetGrampus(req.DatasetInfos)
  94. if len(req.ModelName) != 0 {
  95. modelGrampus = []models.GrampusDataset{
  96. {
  97. Name: req.ModelName,
  98. Bucket: setting.Bucket,
  99. EndPoint: getEndPoint(),
  100. ObjectKey: req.PreTrainModelPath,
  101. },
  102. }
  103. }
  104. codeGrampus = models.GrampusDataset{
  105. Name: req.CodeName,
  106. Bucket: setting.Bucket,
  107. EndPoint: getEndPoint(),
  108. ObjectKey: req.CodeObsPath + cloudbrain.DefaultBranchName + ".zip",
  109. }
  110. }
  111. jobResult, err := createJob(models.CreateGrampusJobRequest{
  112. Name: req.JobName,
  113. Tasks: []models.GrampusTasks{
  114. {
  115. Name: req.JobName,
  116. Command: req.Command,
  117. ResourceSpecId: req.Spec.SourceSpecId,
  118. ImageId: req.ImageId,
  119. ImageUrl: req.ImageUrl,
  120. CenterID: req.Spec.GetAvailableCenterIds(ctx.User.ID),
  121. ReplicaNum: 1,
  122. Datasets: datasetGrampus,
  123. Models: modelGrampus,
  124. Code: codeGrampus,
  125. BootFile: req.BootFile,
  126. },
  127. },
  128. })
  129. if err != nil {
  130. log.Error("createJob failed: %v", err.Error())
  131. return "", err
  132. }
  133. jobID := jobResult.JobInfo.JobID
  134. err = models.CreateCloudbrain(&models.Cloudbrain{
  135. Status: TransTrainJobStatus(jobResult.JobInfo.Status),
  136. UserID: ctx.User.ID,
  137. RepoID: ctx.Repo.Repository.ID,
  138. JobID: jobID,
  139. JobName: req.JobName,
  140. DisplayJobName: req.DisplayJobName,
  141. JobType: string(models.JobTypeTrain),
  142. Type: models.TypeC2Net,
  143. Uuid: req.Uuid,
  144. DatasetName: req.DatasetNames,
  145. CommitID: req.CommitID,
  146. IsLatestVersion: req.IsLatestVersion,
  147. ComputeResource: req.ComputeResource,
  148. ImageID: req.ImageId,
  149. TrainUrl: req.TrainUrl,
  150. BranchName: req.BranchName,
  151. Parameters: req.Params,
  152. BootFile: req.BootFile,
  153. DataUrl: req.DataUrl,
  154. Description: req.Description,
  155. WorkServerNumber: req.WorkServerNumber,
  156. EngineName: req.EngineName,
  157. VersionCount: req.VersionCount,
  158. TotalVersionCount: req.TotalVersionCount,
  159. CreatedUnix: createTime,
  160. UpdatedUnix: createTime,
  161. Spec: req.Spec,
  162. ModelName: req.ModelName,
  163. ModelVersion: req.ModelVersion,
  164. LabelName: req.LabelName,
  165. PreTrainModelUrl: req.PreTrainModelUrl,
  166. CkptName: req.CkptName,
  167. })
  168. if err != nil {
  169. log.Error("CreateCloudbrain(%s) failed:%v", req.DisplayJobName, err.Error())
  170. return "", err
  171. }
  172. var actionType models.ActionType
  173. if req.ComputeResource == models.NPUResource {
  174. actionType = models.ActionCreateGrampusNPUTrainTask
  175. } else if req.ComputeResource == models.GPUResource {
  176. actionType = models.ActionCreateGrampusGPUTrainTask
  177. }
  178. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, req.DisplayJobName, actionType)
  179. return jobID, nil
  180. }
  181. func getCentersParamter(ctx *context.Context, req *GenerateTrainJobReq) ([]string, []string) {
  182. var centerID []string
  183. var centerName []string
  184. includeCenters := make(map[string]string)
  185. excludeCenters := make(map[string]string)
  186. if SpecialPools != nil {
  187. for _, pool := range SpecialPools.Pools {
  188. if !pool.IsExclusive && strings.Contains(req.ComputeResource, pool.Type) {
  189. org, _ := models.GetOrgByName(pool.Org)
  190. if org != nil {
  191. isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID)
  192. if isOrgMember {
  193. for _, info := range pool.Pool {
  194. includeCenters[info.Queue] = info.Value
  195. }
  196. } else {
  197. for _, info := range pool.Pool {
  198. excludeCenters[info.Queue] = info.Value
  199. }
  200. }
  201. }
  202. }
  203. }
  204. }
  205. if len(includeCenters) > 0 {
  206. //如果有专属资源池,根据专属资源池指定智算中心
  207. for k, v := range includeCenters {
  208. centerID = append(centerID, k)
  209. centerName = append(centerName, v)
  210. }
  211. } else if len(excludeCenters) > 0 {
  212. //否则,有要排除的中心,先获取所有中心,删除其中的排除中心,得到指定的智算中心
  213. allCenters := make(map[string]string)
  214. specs, err := GetResourceSpecs(req.ProcessType)
  215. if err == nil {
  216. for _, info := range specs.Infos {
  217. for _, center := range info.Centers {
  218. allCenters[center.ID] = center.Name
  219. }
  220. }
  221. }
  222. for k, _ := range excludeCenters {
  223. delete(allCenters, k)
  224. }
  225. for k, v := range allCenters {
  226. centerID = append(centerID, k)
  227. centerName = append(centerName, v)
  228. }
  229. }
  230. return centerID, centerName
  231. }
  232. func TransTrainJobStatus(status string) string {
  233. if status == models.GrampusStatusPending {
  234. status = models.GrampusStatusWaiting
  235. }
  236. return strings.ToUpper(status)
  237. }
  238. func InitSpecialPool() {
  239. if SpecialPools == nil && setting.Grampus.SpecialPools != "" {
  240. json.Unmarshal([]byte(setting.Grampus.SpecialPools), &SpecialPools)
  241. }
  242. }
  243. func GetNpuModelRemoteObsUrl(jobName string) string {
  244. return "s3:///" + BucketRemote + "/" + GetNpuModelObjectKey(jobName)
  245. }
  246. func GetNpuModelObjectKey(jobName string) string {
  247. return setting.CodePathPrefix + jobName + RemoteModelPath
  248. }
  249. func GetRemoteEndPoint(aiCenterID string) string {
  250. var endPoint string
  251. for _, info := range setting.CenterInfos.Info {
  252. if info.CenterID == aiCenterID {
  253. endPoint = info.Endpoint
  254. break
  255. }
  256. }
  257. return endPoint
  258. }
  259. func GetCenterProxy(aiCenterID string) string {
  260. var proxy string
  261. for _, info := range setting.CenterInfos.Info {
  262. if info.CenterID == aiCenterID {
  263. proxy = info.StorageProxyServer
  264. break
  265. }
  266. }
  267. return proxy
  268. }