master.ml 6.71 KB
Newer Older
Fardale's avatar
Fardale committed
1 2 3 4
open Lwt.Infix
open Serialization_t
module SHashtbl = CCHashtbl.Make (CCString)

5 6 7
(* TODO: store the currently running jobs and check from time
 * to time that the nodes are still alive *)

Fardale's avatar
Fardale committed
8 9 10 11 12 13 14
(** Table of the known nodes. Each entry pairs a node record with the list
    of computations currently attached to that node. *)
let nodes : (node * computation list) SHashtbl.t = SHashtbl.create 10

(** Queue of the computations that are waiting to be sent to a node. *)
let jobs : computation CCDeque.t = CCDeque.create ()

(** Counter used as the first component of computation ids. *)
let jobs_id = ref 0

(** [string_of_ret_code status] renders a process status as the name of its
    constructor followed by the integer it carries, e.g. ["WEXITED 0"]. *)
let string_of_ret_code status =
  match status with
  | `WEXITED code -> Printf.sprintf "WEXITED %i" code
  | `WSIGNALED signal -> Printf.sprintf "WSIGNALED %i" signal
  | `WSTOPPED signal -> Printf.sprintf "WSTOPPED %i" signal
Fardale's avatar
Fardale committed
21 22

(** [string_of_sockaddr addr] returns the textual form of a socket address:
    the path of a Unix-domain socket, or the IP address of an Internet
    socket. The port of an Internet socket is not part of the result. *)
let string_of_sockaddr addr =
  match addr with
  | Unix.ADDR_INET (ip, _port) -> Unix.string_of_inet_addr ip
  | Unix.ADDR_UNIX path -> path
Fardale's avatar
Fardale committed
27 28 29 30

(** [end_job id sockaddr] releases the resources held by the computation
    [id] on the node registered under [string_of_sockaddr sockaddr]: the
    [cpu] and [ram] of the job are given back to the node and the job is
    removed from the node's list of computations.

    The node entry is looked up and updated in one step, after the
    [Lwt.pause], so that a change made to the entry by another thread during
    the pause is not overwritten with stale values. An unknown node or an
    unknown computation is reported with a warning instead of raising
    [Not_found]. *)
let end_job id sockaddr =
  let key = string_of_sockaddr sockaddr in
  Lwt.pause ()
  >>= fun () ->
  match SHashtbl.get nodes key with
  | None ->
      Logs_lwt.warn (fun m -> m "Result received from unknown node %s" key)
  | Some (n, l) -> (
    match List.find_opt (fun (c : computation) -> c.id = id) l with
    | None ->
        Logs_lwt.warn (fun m ->
            m "Computation %i,%i is not attached to node %s" (fst id)
              (snd id) key )
    | Some j ->
        SHashtbl.replace nodes key
          ( {n with cpu= n.cpu + j.cpu; ram= n.ram + j.ram}
          , CCList.remove
              ~eq:(fun (c1 : computation) (c2 : computation) ->
                c1.id = c2.id )
              ~key:j l ) ;
        Lwt.return_unit )

(** [launch_job ()] sends queued computations to nodes, one after the other.

    The computation at the front of [jobs] is assigned to the first node of
    [nodes] (in table iteration order) that has at least the requested [cpu]
    and [ram]. The resources are reserved in the table before the
    computation is sent, then the function pauses and starts again with the
    next computation.

    The round stops when the queue is empty, when no node can host the front
    computation (it is put back at the front of the queue), or when sending
    fails. In the last case the reservation is undone, the computation is
    put back at the front of the queue and the error is logged, so that
    neither the job nor the node's resources are lost; the job is tried
    again the next time [launch_job] runs. *)
let rec launch_job () =
  if CCDeque.is_empty jobs then Logs_lwt.debug (fun m -> m "No computation")
  else
    let j = CCDeque.take_front jobs in
    (* Fold step: keep the first node found that can host [j]. *)
    let pick k ((n : node), l) found =
      match found with
      | Some _ ->
          found
      | None ->
          if n.cpu >= j.cpu && n.ram >= j.ram then Some (k, n, l) else None
    in
    match SHashtbl.fold pick nodes None with
    | None ->
        CCDeque.push_front jobs j ;
        Logs_lwt.debug (fun m -> m "No free node")
    | Some (k, n, l) ->
        (* Reserve the resources right away: there is no yield point between
           taking the job and this update, so two concurrent rounds cannot
           pick the same free slot. *)
        SHashtbl.replace nodes k
          ({n with cpu= n.cpu - j.cpu; ram= n.ram - j.ram}, j :: l) ;
        let send () =
          Logs_lwt.debug (fun m ->
              m "Send computation %i,%i to %s on %i" (fst j.id) (snd j.id)
                n.addr n.port )
          >>= fun () ->
          let sockaddr =
            Unix.ADDR_INET (Unix.inet_addr_of_string n.addr, n.port)
          in
          Lwt_io.with_connection sockaddr (fun (_ic, oc) ->
              Lwt_io.write_line oc
                (Serialization_j.string_of_query (`COMPUTATION j)) )
        in
        let rollback exn =
          (* Re-read the entry: it may have changed while we were sending. *)
          ( match SHashtbl.get nodes k with
          | Some (cur, running) ->
              SHashtbl.replace nodes k
                ( {cur with cpu= cur.cpu + j.cpu; ram= cur.ram + j.ram}
                , CCList.remove
                    ~eq:(fun (c1 : computation) (c2 : computation) ->
                      c1.id = c2.id )
                    ~key:j running )
          | None ->
              () ) ;
          CCDeque.push_front jobs j ;
          Logs_lwt.err (fun m ->
              m "Cannot send computation %i,%i to %s on %i: %s" (fst j.id)
                (snd j.id) n.addr n.port (Printexc.to_string exn) )
          >|= fun () -> false
        in
        Lwt.catch (fun () -> send () >|= fun () -> true) rollback
        >>= fun sent ->
        (* After a failure, stop instead of looping on the same node. *)
        if sent then Lwt.pause () >>= launch_job else Lwt.return_unit
Fardale's avatar
Fardale committed
72

73
(** [send_answer oc answer] serializes [answer], writes it as one line on
    [oc] and flushes the channel. *)
let send_answer oc answer =
  Lwt_io.write_line oc (Serialization_j.string_of_answer answer)
  >>= fun () -> Lwt_io.flush oc

(** [server_handler pass port sockaddr (ic, oc)] handles one client
    connection. It reads a single line from [ic], decodes it as a query and
    answers on [oc]:

    - [`RESULT]: when the password matches [pass], acknowledges, then, in a
      background task, writes the job's stdout, stderr and a small log to
      [ocluster_<id>_<id>.out], [.err] and [.log], releases the job with
      [end_job] and calls [launch_job]. Errors of the background task are
      logged instead of being left to the async exception hook.
    - [`JOB]: when the password matches [pass], acknowledges, queues one
      computation per iteration of the submission (each carrying [pass] and
      [port]) and starts [launch_job] if the queue was empty.
    - any other query, a wrong password or an undecodable line: answers with
      an [`Error] and logs it.

    A connection closed before any line was received is logged as a warning
    instead of rejecting the handler with [End_of_file]. *)
let server_handler pass port sockaddr (ic, oc) =
  Lwt_io.read_line_opt ic
  >>= function
  | None ->
      Logs_lwt.warn (fun m ->
          m "Connection closed before a query was received" )
  | Some json -> (
    match CCResult.guard (fun () -> Serialization_j.query_of_string json) with
    | Result.Ok query -> (
      match query with
      | `RESULT result ->
          if pass = result.pass then
            send_answer oc `Ok
            <&> Logs_lwt.debug (fun m ->
                    (* The outputs can be large: keep them out of the log. *)
                    m "Receive result: %s"
                      (Serialization_j.string_of_result
                         {result with stdout= "<stdout>"; stderr= "<stderr>"}) )
            >>= fun () ->
            (* The files are written in the background so that the handler
               can return without waiting for them. *)
            Lwt.return
              (Lwt.async (fun () ->
                   Lwt.catch
                     (fun () ->
                       Lwt.join
                         [ Lwt_io.with_file ~mode:Lwt_io.output
                             (Printf.sprintf "ocluster_%i_%i.out"
                                (fst result.id) (snd result.id))
                             (fun oc -> Lwt_io.write oc result.stdout)
                         ; Lwt_io.with_file ~mode:Lwt_io.output
                             (Printf.sprintf "ocluster_%i_%i.err"
                                (fst result.id) (snd result.id))
                             (fun oc -> Lwt_io.write oc result.stderr)
                         ; Lwt_io.with_file ~mode:Lwt_io.output
                             (Printf.sprintf "ocluster_%i_%i.log"
                                (fst result.id) (snd result.id))
                             (fun oc ->
                               Lwt_io.write oc
                                 (Printf.sprintf
                                    "Job completed at %f\nReturn code: %s\n"
                                    (Unix.time ())
                                    (string_of_ret_code result.ret_code)) )
                         ; end_job result.id sockaddr
                           >>= Lwt.pause >>= launch_job ] )
                     (fun exn ->
                       Logs_lwt.err (fun m ->
                           m "Error while storing result %i,%i: %s"
                             (fst result.id) (snd result.id)
                             (Printexc.to_string exn) ) ) ))
          else
            send_answer oc (`Error "Wrong password")
            <&> Logs_lwt.warn (fun m -> m "Wrong password: %s" result.pass)
      | `JOB submission ->
          if pass = submission.pass then (
            send_answer oc `Ok
            <&> Logs_lwt.debug (fun m ->
                    m "Receive submission: %s"
                      (Serialization_j.string_of_submission submission) )
            >>= fun () ->
            (* Remember whether the queue was empty before adding the new
               computations: scheduling is only started in that case. *)
            let empty = CCDeque.is_empty jobs in
            for i = 0 to submission.iteration - 1 do
              let computation =
                { id= (!jobs_id, i)
                ; env=
                    [ Printf.sprintf "OCLUSTER_ARRAY_TASK_ID=%i" i
                    ; Printf.sprintf "OCLUSTER_TASK_ID=%i" !jobs_id ]
                ; script= submission.script
                ; time= submission.time
                ; pass
                ; port
                ; cpu= submission.cpu
                ; ram= submission.ram }
              in
              CCDeque.push_back jobs computation
            done ;
            incr jobs_id ;
            if empty then launch_job () else Lwt.return_unit )
          else
            send_answer oc (`Error "Wrong password")
            <&> Logs_lwt.warn (fun m ->
                    m "Wrong password: %s" submission.pass )
      | _ ->
          send_answer oc (`Error "Unwanted command")
          <&> Logs_lwt.warn (fun m -> m "Receive a unwanted command") )
    | Result.Error e ->
        send_answer oc
          (`Error
            (Printf.sprintf
               "Error during the reception of the computation: %s"
               (Printexc.to_string e)))
        <&> Logs_lwt.err (fun m ->
                m "Error during the reception of the computation: %s"
                  (Printexc.to_string e) ) )
Fardale's avatar
Fardale committed
162 163 164

(** [stop_server resolver server _] resolves [resolver] with [server]. The
    last argument is ignored, which lets the partial application
    [stop_server resolver server] be used as a signal handler.

    [Lwt.wakeup_later] raises [Invalid_argument] when the promise is already
    resolved; this is ignored so that a second call (for instance a second
    signal received during shutdown) is a no-op. *)
let stop_server resolver server _ =
  try Lwt.wakeup_later resolver server with Invalid_argument _ -> ()

Fardale's avatar
Fardale committed
165
(** [cmd config ()] runs the master.

    It reads the master configuration from the file [config], registers
    every configured node in [nodes] under its address with an empty list of
    computations, and starts a server on the configured port whose
    connections are handled by [server_handler]. The returned promise
    resolves once a SIGTERM or SIGINT has been received and the server has
    been shut down. *)
let cmd config () =
  let conf =
    Serialization_j.master_conf_of_string (CCIO.with_in config CCIO.read_all)
  in
  (* [replace] rather than [add]: an address listed twice must not leave a
     hidden second binding in the table. *)
  List.iter (fun n -> SHashtbl.replace nodes n.addr (n, [])) conf.nodes ;
  let promise, resolver = Lwt.task () in
  Logs_lwt.info (fun m -> m "master at %i with pass %s" conf.port conf.pass)
  >>= fun () ->
  Lwt_io.establish_server_with_client_address
    (Unix.ADDR_INET (Unix.inet_addr_any, conf.port))
    (server_handler conf.pass conf.port)
  >>= fun server ->
  let on_stop = stop_server resolver server in
  (* The handlers stay installed for the whole life of the process, so
     their ids are not kept. *)
  let (_ : Lwt_unix.signal_handler_id) =
    Lwt_unix.on_signal Sys.sigterm on_stop
  in
  let (_ : Lwt_unix.signal_handler_id) =
    Lwt_unix.on_signal Sys.sigint on_stop
  in
  promise
  >>= fun server ->
  Lwt_io.shutdown_server server
  <&> Logs_lwt.info (fun m -> m "Shuting down server")