You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

grampus.go 7.8 kB

3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
3 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package grampus
  2. import (
  3. "code.gitea.io/gitea/modules/cloudbrain"
  4. "encoding/json"
  5. "strings"
  6. "code.gitea.io/gitea/modules/setting"
  7. "code.gitea.io/gitea/models"
  8. "code.gitea.io/gitea/modules/context"
  9. "code.gitea.io/gitea/modules/log"
  10. "code.gitea.io/gitea/modules/notification"
  11. "code.gitea.io/gitea/modules/timeutil"
  12. )
  // Constants shared by the Grampus job helpers.
  const (
  	// JobPath is the "job/" path segment; presumably joined into Grampus API
  	// URLs elsewhere in this package — TODO confirm.
  	JobPath = "job/"
  )

  // Processor type identifiers. GenerateTrainJob compares req.ProcessType with
  // ProcessorTypeNPU to decide whether object-storage inputs are attached.
  const (
  	ProcessorTypeNPU = "npu.huawei.com/NPU"
  	ProcessorTypeGPU = "nvidia.com/gpu"
  )

  // Working directories for GPU and NPU jobs; presumably used when building job
  // commands elsewhere in this package — TODO confirm.
  const (
  	GpuWorkDir = "/tmp/"
  	NpuWorkDir = "/cache/"
  )

  // CodeArchiveName is the code archive file name; it matches the "master.zip"
  // that CommandPrepareScript downloads and unzips.
  const CodeArchiveName = "master.zip"
  21. var (
  22. poolInfos *models.PoolInfos
  23. FlavorInfos *setting.StFlavorInfos
  24. ImageInfos *setting.StImageInfosModelArts
  25. SpecialPools *models.SpecialPools
  26. CommandPrepareScript = ";mkdir -p output;mkdir -p code;mkdir -p dataset;mkdir -p pretrainmodel;echo \"start loading script\";wget -q https://git.openi.org.cn/OpenIOSSG/%s/archive/master.zip;" +
  27. "echo \"finish loading script\";unzip -q master.zip;cd %s;chmod 777 downloader_for_obs uploader_for_npu downloader_for_minio uploader_for_gpu;"
  28. )
  29. type GenerateTrainJobReq struct {
  30. JobName string
  31. Command string
  32. ImageUrl string //与image_id二选一,都有的情况下优先image_url
  33. ImageId string
  34. DisplayJobName string
  35. Uuid string
  36. Description string
  37. CodeObsPath string
  38. BootFile string
  39. BootFileUrl string
  40. DataUrl string
  41. TrainUrl string
  42. WorkServerNumber int
  43. EngineID int64
  44. CommitID string
  45. IsLatestVersion string
  46. BranchName string
  47. PreVersionId int64
  48. PreVersionName string
  49. VersionCount int
  50. EngineName string
  51. TotalVersionCount int
  52. ComputeResource string
  53. ProcessType string
  54. DatasetNames string
  55. DatasetInfos map[string]models.DatasetInfo
  56. Params string
  57. ModelName string
  58. LabelName string
  59. CkptName string
  60. ModelVersion string
  61. PreTrainModelPath string
  62. PreTrainModelUrl string
  63. Spec *models.Specification
  64. CodeName string
  65. }
  66. func getEndPoint() string {
  67. index := strings.Index(setting.Endpoint, "//")
  68. endpoint := setting.Endpoint[index+2:]
  69. return endpoint
  70. }
  71. func getDatasetGrampus(datasetInfos map[string]models.DatasetInfo) []models.GrampusDataset {
  72. var datasetGrampus []models.GrampusDataset
  73. endPoint := getEndPoint()
  74. for _, datasetInfo := range datasetInfos {
  75. datasetGrampus = append(datasetGrampus, models.GrampusDataset{
  76. Name: datasetInfo.FullName,
  77. Bucket: setting.Bucket,
  78. EndPoint: endPoint,
  79. ObjectKey: datasetInfo.DataLocalPath + datasetInfo.FullName,
  80. })
  81. }
  82. return datasetGrampus
  83. }
  84. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  85. createTime := timeutil.TimeStampNow()
  86. centerID, centerName := getCentersParamter(ctx, req)
  87. var datasetGrampus, modelGrampus []models.GrampusDataset
  88. var codeGrampus models.GrampusDataset
  89. if ProcessorTypeNPU == req.ProcessType {
  90. datasetGrampus = getDatasetGrampus(req.DatasetInfos)
  91. if len(req.ModelName) != 0 {
  92. modelGrampus = []models.GrampusDataset{
  93. {
  94. Name: req.ModelName,
  95. Bucket: setting.Bucket,
  96. EndPoint: getEndPoint(),
  97. ObjectKey: req.PreTrainModelPath,
  98. },
  99. }
  100. }
  101. codeGrampus = models.GrampusDataset{
  102. Name: req.CodeName,
  103. Bucket: setting.Bucket,
  104. EndPoint: getEndPoint(),
  105. ObjectKey: req.CodeObsPath + cloudbrain.DefaultBranchName + ".zip",
  106. }
  107. }
  108. jobResult, err := createJob(models.CreateGrampusJobRequest{
  109. Name: req.JobName,
  110. Tasks: []models.GrampusTasks{
  111. {
  112. Name: req.JobName,
  113. Command: req.Command,
  114. ResourceSpecId: req.Spec.SourceSpecId,
  115. ImageId: req.ImageId,
  116. ImageUrl: req.ImageUrl,
  117. CenterID: centerID,
  118. CenterName: centerName,
  119. ReplicaNum: 1,
  120. Datasets: datasetGrampus,
  121. Models: modelGrampus,
  122. Code: codeGrampus,
  123. BootFile: req.BootFile,
  124. },
  125. },
  126. })
  127. if err != nil {
  128. log.Error("createJob failed: %v", err.Error())
  129. return err
  130. }
  131. jobID := jobResult.JobInfo.JobID
  132. err = models.CreateCloudbrain(&models.Cloudbrain{
  133. Status: TransTrainJobStatus(jobResult.JobInfo.Status),
  134. UserID: ctx.User.ID,
  135. RepoID: ctx.Repo.Repository.ID,
  136. JobID: jobID,
  137. JobName: req.JobName,
  138. DisplayJobName: req.DisplayJobName,
  139. JobType: string(models.JobTypeTrain),
  140. Type: models.TypeC2Net,
  141. Uuid: req.Uuid,
  142. DatasetName: req.DatasetNames,
  143. CommitID: req.CommitID,
  144. IsLatestVersion: req.IsLatestVersion,
  145. ComputeResource: req.ComputeResource,
  146. ImageID: req.ImageId,
  147. TrainUrl: req.TrainUrl,
  148. BranchName: req.BranchName,
  149. Parameters: req.Params,
  150. BootFile: req.BootFile,
  151. DataUrl: req.DataUrl,
  152. Description: req.Description,
  153. WorkServerNumber: req.WorkServerNumber,
  154. EngineName: req.EngineName,
  155. VersionCount: req.VersionCount,
  156. TotalVersionCount: req.TotalVersionCount,
  157. CreatedUnix: createTime,
  158. UpdatedUnix: createTime,
  159. Spec: req.Spec,
  160. ModelName: req.ModelName,
  161. ModelVersion: req.ModelVersion,
  162. LabelName: req.LabelName,
  163. PreTrainModelUrl: req.PreTrainModelUrl,
  164. CkptName: req.CkptName,
  165. })
  166. if err != nil {
  167. log.Error("CreateCloudbrain(%s) failed:%v", req.DisplayJobName, err.Error())
  168. return err
  169. }
  170. var actionType models.ActionType
  171. if req.ComputeResource == models.NPUResource {
  172. actionType = models.ActionCreateGrampusNPUTrainTask
  173. } else if req.ComputeResource == models.GPUResource {
  174. actionType = models.ActionCreateGrampusGPUTrainTask
  175. }
  176. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, req.DisplayJobName, actionType)
  177. return nil
  178. }
  179. func getCentersParamter(ctx *context.Context, req *GenerateTrainJobReq) ([]string, []string) {
  180. var centerID []string
  181. var centerName []string
  182. includeCenters := make(map[string]string)
  183. excludeCenters := make(map[string]string)
  184. if SpecialPools != nil {
  185. for _, pool := range SpecialPools.Pools {
  186. if !pool.IsExclusive && strings.Contains(req.ComputeResource, pool.Type) {
  187. org, _ := models.GetOrgByName(pool.Org)
  188. if org != nil {
  189. isOrgMember, _ := models.IsOrganizationMember(org.ID, ctx.User.ID)
  190. if isOrgMember {
  191. for _, info := range pool.Pool {
  192. includeCenters[info.Queue] = info.Value
  193. }
  194. } else {
  195. for _, info := range pool.Pool {
  196. excludeCenters[info.Queue] = info.Value
  197. }
  198. }
  199. }
  200. }
  201. }
  202. }
  203. if len(includeCenters) > 0 {
  204. //如果有专属资源池,根据专属资源池指定智算中心
  205. for k, v := range includeCenters {
  206. centerID = append(centerID, k)
  207. centerName = append(centerName, v)
  208. }
  209. } else if len(excludeCenters) > 0 {
  210. //否则,有要排除的中心,先获取所有中心,删除其中的排除中心,得到指定的智算中心
  211. allCenters := make(map[string]string)
  212. specs, err := GetResourceSpecs(req.ProcessType)
  213. if err == nil {
  214. for _, info := range specs.Infos {
  215. for _, center := range info.Centers {
  216. allCenters[center.ID] = center.Name
  217. }
  218. }
  219. }
  220. for k, _ := range excludeCenters {
  221. delete(allCenters, k)
  222. }
  223. for k, v := range allCenters {
  224. centerID = append(centerID, k)
  225. centerName = append(centerName, v)
  226. }
  227. }
  228. return centerID, centerName
  229. }
  230. func TransTrainJobStatus(status string) string {
  231. if status == models.GrampusStatusPending {
  232. status = models.GrampusStatusWaiting
  233. }
  234. return strings.ToUpper(status)
  235. }
  236. func InitSpecialPool() {
  237. if SpecialPools == nil && setting.Grampus.SpecialPools != "" {
  238. json.Unmarshal([]byte(setting.Grampus.SpecialPools), &SpecialPools)
  239. }
  240. }