master.ml 5.66 KB
Newer Older
Fardale's avatar
Fardale committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
open Lwt.Infix
open Serialization_t
(** Hash tables keyed by strings. *)
module SHashtbl = CCHashtbl.Make (CCString)

(** Known worker nodes, keyed by address. Each entry holds the node record,
    whose [cpu] and [ram] fields track the resources still free on that node
    ([launch_job] subtracts a computation's request, [end_job] adds it back),
    together with the computations currently dispatched to the node. *)
let nodes : (node * computation list) SHashtbl.t = SHashtbl.create 10

(** Computations waiting to be dispatched, consumed from the front. *)
let jobs : computation CCDeque.t = CCDeque.create ()

(** Identifier given to the next accepted submission; the [i]-th computation
    of a submission gets the id [(!jobs_id, i)]. *)
let jobs_id = ref 0

(** [string_of_ret_code rc] renders a process status as ["<KIND> <n>"],
    e.g. ["WEXITED 0"] or ["WSIGNALED -9"]. *)
let string_of_ret_code ret_code =
  let kind, code =
    match ret_code with
    | `WEXITED status -> ("WEXITED", status)
    | `WSIGNALED signal -> ("WSIGNALED", signal)
    | `WSTOPPED signal -> ("WSTOPPED", signal)
  in
  Printf.sprintf "%s %i" kind code

(** [string_of_sockaddr addr] is the textual host part of [addr]: the
    socket path for a Unix-domain address, the IP (port dropped) for an
    Internet address. *)
let string_of_sockaddr sockaddr =
  match sockaddr with
  | Unix.ADDR_INET (inet_addr, _port) -> Unix.string_of_inet_addr inet_addr
  | Unix.ADDR_UNIX path -> path

(** [end_job id sockaddr] releases the resources held by computation [id] on
    the node registered under [string_of_sockaddr sockaddr]: the
    computation's [cpu] and [ram] are given back to the node and the
    computation is removed from the node's running list.

    The table is read and written without yielding in between, and before
    the returned promise is created. So a concurrent [launch_job] reservation
    cannot be overwritten by a stale snapshot, and a [launch_job] started
    right after this call already sees the freed resources.

    If no node is registered under the sender's address, or that node is not
    running [id], a warning is logged instead of raising [Not_found].

    NOTE(review): the lookup uses the sender's IP in textual form; a node
    configured by hostname (or another textual form of its IP) will not
    match — confirm against the configuration format. *)
let end_job id sockaddr =
  let key = string_of_sockaddr sockaddr in
  match SHashtbl.find_opt nodes key with
  | None -> Logs_lwt.warn (fun m -> m "Result from unknown node %s\n" key)
  | Some (n, l) -> (
    match List.find_opt (fun c -> c.id = id) l with
    | None ->
        Logs_lwt.warn (fun m ->
            m "Node %s is not running computation %i,%i\n" key (fst id)
              (snd id) )
    | Some j ->
        SHashtbl.replace nodes key
          ( {n with cpu= n.cpu + j.cpu; ram= n.ram + j.ram}
          , CCList.remove ~eq:(fun c1 c2 -> c1.id = c2.id) ~key:j l ) ;
        Lwt.return_unit )

(** [find_free_node j] is [Some (key, node, running)] for a node of [nodes]
    whose free [cpu] and [ram] both cover the request of [j], or [None] when
    no node currently fits. The first match in the table's iteration order
    wins. *)
let find_free_node (j : computation) =
  SHashtbl.fold
    (fun k ((n : node), l) found ->
      match found with
      | Some _ -> found
      | None ->
          if n.cpu >= j.cpu && n.ram >= j.ram then Some (k, n, l) else None )
    nodes None

(** [release_reservation key j] gives the [cpu] and [ram] reserved for [j]
    back to the node stored under [key] and removes [j] from that node's
    running computations. Does nothing if [key] is not in [nodes]. *)
let release_reservation key (j : computation) =
  match SHashtbl.find_opt nodes key with
  | None -> ()
  | Some (n, l) ->
      SHashtbl.replace nodes key
        ( {n with cpu= n.cpu + j.cpu; ram= n.ram + j.ram}
        , CCList.remove ~eq:(fun c1 c2 -> c1.id = c2.id) ~key:j l )

(** [launch_job ()] dispatches queued computations to nodes, in queue order.

    It takes the computation at the front of [jobs] and looks for a node
    with enough free [cpu] and [ram]:
    - no such node: the computation is put back at the front of the queue
      and the loop stops until the next call;
    - otherwise the node's resources are reserved, the computation is added
      to the node's running list and sent to the node as a [`COMPUTATION]
      query, then the loop continues with the next computation.

    If sending fails (unreachable node, unparsable [addr], ...) the
    reservation is undone, the computation is requeued at the front, the
    error is logged and the loop stops. The failure is not propagated to the
    connection handler that started the loop. *)
let rec launch_job () =
  if CCDeque.is_empty jobs then Logs_lwt.debug (fun m -> m "No compuatiton\n")
  else
    let j = CCDeque.take_front jobs in
    match find_free_node j with
    | None ->
        (* Keep the queue order: the computation stays at the head. *)
        CCDeque.push_front jobs j ;
        Logs_lwt.debug (fun m -> m "No free node\n")
    | Some (k, n, l) ->
        (* Reserve before any yield, so no other [launch_job] can hand the
           same resources to another computation. *)
        SHashtbl.replace nodes k
          ({n with cpu= n.cpu - j.cpu; ram= n.ram - j.ram}, j :: l) ;
        Logs_lwt.debug (fun m ->
            m "Send computation %i,%i to %s on %i\n" (fst j.id) (snd j.id)
              n.addr n.port )
        >>= fun () ->
        Lwt_io.flush Lwt_io.stderr
        >>= fun () ->
        Lwt.catch
          (fun () ->
            let sockaddr =
              Unix.ADDR_INET (Unix.inet_addr_of_string n.addr, n.port)
            in
            Lwt_io.with_connection sockaddr (fun (_ic, oc) ->
                Lwt_io.write oc
                  (Serialization_j.string_of_query (`COMPUTATION j)) )
            >|= fun () -> true )
          (fun e ->
            release_reservation k j ;
            CCDeque.push_front jobs j ;
            Logs_lwt.err (fun m ->
                m "Failed to send computation %i,%i to %s on %i: %s\n"
                  (fst j.id) (snd j.id) n.addr n.port (Printexc.to_string e) )
            >|= fun () -> false )
        >>= fun sent ->
        (* Stop on failure instead of immediately retrying the same node. *)
        if sent then Lwt.pause () >>= launch_job else Lwt.return_unit

(** [write_file path contents] creates (or truncates) [path] and writes
    [contents] to it. *)
let write_file path contents =
  Lwt_io.with_file ~mode:Lwt_io.output path (fun oc ->
      Lwt_io.write oc contents )

(** [server_handler pass port sockaddr (ic, _oc)] handles one connection to
    the master. [sockaddr] is the client's address.

    The whole input is read and parsed as a [query]:
    - [`RESULT result]: if [result.pass] equals [pass], the computation's
      stdout, stderr and a log line (completion time and return code) are
      written to [ocluster_<job>_<task>.out], [.err] and [.log] in the
      working directory. The sender's resources are then released with
      [end_job] and, only once that is done, [launch_job] is restarted.
    - [`JOB submission]: if the password matches, [submission.iteration]
      computations with ids [(!jobs_id, 0 .. iteration - 1)] are appended to
      [jobs] and [jobs_id] is incremented. [port] is stored in each
      computation (presumably so the node knows where to send the result —
      confirm). The dispatch loop is started if the queue was empty.
    - wrong passwords, other queries and unparsable input are only logged.

    Nothing is written back on [_oc]. *)
let server_handler pass port sockaddr (ic, _oc) =
  Lwt_io.read ic
  >>= fun json ->
  match CCResult.guard (fun () -> Serialization_j.query_of_string json) with
  | Result.Ok query -> (
    match query with
    | `RESULT result ->
        if pass = result.pass then (
          Logs_lwt.debug (fun m ->
              m "Receive result: %s\n"
                (Serialization_j.string_of_result
                   {result with stdout= "<stdout>"; stderr= "<stderr>"}) )
          >>= fun () ->
          let prefix =
            Printf.sprintf "ocluster_%i_%i" (fst result.id) (snd result.id)
          in
          write_file (prefix ^ ".out") result.stdout
          >>= fun () ->
          write_file (prefix ^ ".err") result.stderr
          >>= fun () ->
          write_file (prefix ^ ".log")
            (Printf.sprintf "Job completed at %f\nReturn code: %s\n"
               (Unix.time ())
               (string_of_ret_code result.ret_code))
          >>= fun () ->
          (* Wait for the resources to be released before relaunching:
             otherwise [launch_job] may still see the node as busy, requeue
             the job and leave the queue stalled. *)
          end_job result.id sockaddr >>= launch_job )
        else Logs_lwt.warn (fun m -> m "Wrong password: %s\n" result.pass)
    | `JOB submission ->
        if pass = submission.pass then (
          Logs_lwt.debug (fun m ->
              m "Receive submission: %s\n"
                (Serialization_j.string_of_submission submission) )
          >>= fun () ->
          let was_empty = CCDeque.is_empty jobs in
          for i = 0 to submission.iteration - 1 do
            let computation =
              { id= (!jobs_id, i)
              ; env=
                  [ Printf.sprintf "OCLUSTER_ARRAY_TASK_ID=%i" i
                  ; Printf.sprintf "OCLUSTER_TASK_ID=%i" !jobs_id ]
              ; script= submission.script
              ; time= submission.time
              ; pass
              ; port
              ; cpu= submission.cpu
              ; ram= submission.ram }
            in
            CCDeque.push_back jobs computation
          done ;
          incr jobs_id ;
          (* A non-empty queue means dispatch is already pending. *)
          if was_empty then launch_job () else Lwt.return_unit )
        else
          Logs_lwt.warn (fun m -> m "Wrong password: %s\n" submission.pass)
    | _ -> Logs_lwt.warn (fun m -> m "Receive a unwanted command\n") )
  | Result.Error e ->
      Logs_lwt.err (fun m ->
          m "Error during the reception of the computation: %s\n"
            (Printexc.to_string e) )

(** [stop_server resolver server] is a signal handler that, whatever signal
    it receives, resolves the promise behind [resolver] with [server] using
    [Lwt.wakeup_later]. [cmd] waits on that promise to shut the server down.

    NOTE(review): a second signal arriving after the promise is resolved
    presumably makes [Lwt.wakeup_later] raise [Invalid_argument] — confirm
    whether that is the intended way to force-quit. *)
let stop_server resolver server =
 fun _signal -> Lwt.wakeup_later resolver server

(** [cmd config] runs the master until SIGTERM or SIGINT.

    [config] is the path of a JSON file parsed as a [master_conf]. Every node
    it lists is registered in [nodes] under its [addr], with no running
    computation. The master then listens on [conf.port] on all interfaces and
    hands each connection to [server_handler]. On SIGTERM or SIGINT the
    server is shut down.

    Nodes are keyed by [addr] only. [SHashtbl.replace] keeps a single entry
    per address (the last one listed). With [add], a duplicate address would
    leave a shadowed binding that [SHashtbl.fold] (used by [launch_job])
    still visits, while [find]/[replace] only ever touch the newest binding.

    @raise exceptions from reading or parsing [config]. *)
let cmd config =
  let conf =
    Serialization_j.master_conf_of_string (CCIO.with_in config CCIO.read_all)
  in
  List.iter (fun n -> SHashtbl.replace nodes n.addr (n, [])) conf.nodes ;
  let promise, resolver = Lwt.task () in
  Logs_lwt.app (fun m -> m "master at %i with pass %s\n" conf.port conf.pass)
  >>= fun () ->
  Lwt_io.establish_server_with_client_address
    (Unix.ADDR_INET (Unix.inet_addr_any, conf.port))
    (server_handler conf.pass conf.port)
  >>= fun server ->
  (* Portable signal constants instead of the raw numbers 15 and 2. *)
  List.iter
    (fun signal ->
      ignore
        ( Lwt_unix.on_signal signal (stop_server resolver server)
          : Lwt_unix.signal_handler_id ) )
    [Sys.sigterm; Sys.sigint] ;
  promise
  >>= fun server ->
  Lwt_io.shutdown_server server
  <&> Logs_lwt.app (fun m -> m "Shuting down server\n")