Commit 17660e17 authored by Fardale's avatar Fardale

refactor password in communication

parent 7ad2396f
......@@ -3,11 +3,11 @@ open Serialization_t
let cmd cpu ram time iteration port script_file pass addr args () =
let script = CCIO.with_in script_file CCIO.read_all in
let submission = {script; args; time; iteration; cpu; ram; pass} in
let submission = {script; args; time; iteration; cpu; ram} in
Lwt_io.with_connection
(Lwt_unix.ADDR_INET (Unix.inet_addr_of_string addr, port))
(fun (ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`JOB submission))
Lwt_io.write_line oc (Serialization_j.string_of_query (pass, `JOB submission))
>>= fun () ->
Lwt_io.read ic
>>= fun json ->
......
......@@ -34,7 +34,7 @@ let end_job id sockaddr =
, CCList.remove_one ~eq:(fun c1 c2 -> c1.id = c2.id) j l ) ;
Lwt.return_unit
let rec launch_job () =
let rec launch_job pass () =
if not (CCDeque.is_empty jobs) then (
let j = CCDeque.take_front jobs in
match
......@@ -62,94 +62,88 @@ let rec launch_job () =
>>= fun () ->
Lwt_io.with_connection sockaddr (fun (_ic, oc) ->
Lwt_io.write_line oc
(Serialization_j.string_of_query (`COMPUTATION j)))
(Serialization_j.string_of_query (pass, `COMPUTATION j)))
(* TODO: check return value *)
>>= Lwt.pause
>>= launch_job )
>>= launch_job pass )
else Logs_lwt.debug (fun m -> m "No computation")
let server_handler pass port sockaddr (ic, oc) =
Lwt_io.read_line ic
>>= fun json ->
match CCResult.guard (fun () -> Serialization_j.query_of_string json) with
| Result.Ok query -> (
match query with
| `RESULT result ->
if pass = result.pass then
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive result: %s"
(Serialization_j.string_of_result
{result with stdout= "<stdout>"; stderr= "<stderr>"}))
>>= fun () ->
Lwt.return
(Lwt.async (fun () ->
Lwt.join
[ ( if String.length result.stdout > 0 then
Lwt_io.with_file ~mode:Lwt_io.output
(Printf.sprintf "ocluster_%i_%i.out" (fst result.id)
| Result.Ok (query_pass, query) ->
if query_pass = pass then
match query with
| `RESULT result ->
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive result: %s"
(Serialization_j.string_of_result
{result with stdout= "<stdout>"; stderr= "<stderr>"}))
>>= fun () ->
Lwt.return
(Lwt.async (fun () ->
Lwt.join
[ ( if String.length result.stdout > 0 then
Lwt_io.with_file ~mode:Lwt_io.output
(Printf.sprintf "ocluster_%i_%i.out" (fst result.id)
(snd result.id))
(fun oc -> Lwt_io.write oc result.stdout)
else Lwt.return_unit )
; ( if String.length result.stderr > 0 then
Lwt_io.with_file ~mode:Lwt_io.output
(Printf.sprintf "ocluster_%i_%i.err" (fst result.id)
(snd result.id))
(fun oc -> Lwt_io.write oc result.stderr)
else Lwt.return_unit )
; Lwt_io.with_file ~mode:Lwt_io.output
(Printf.sprintf "ocluster_%i_%i.log" (fst result.id)
(snd result.id))
(fun oc -> Lwt_io.write oc result.stdout)
else Lwt.return_unit )
; ( if String.length result.stderr > 0 then
Lwt_io.with_file ~mode:Lwt_io.output
(Printf.sprintf "ocluster_%i_%i.err" (fst result.id)
(snd result.id))
(fun oc -> Lwt_io.write oc result.stderr)
else Lwt.return_unit )
; Lwt_io.with_file ~mode:Lwt_io.output
(Printf.sprintf "ocluster_%i_%i.log" (fst result.id)
(snd result.id))
(fun oc ->
Lwt_io.write oc
(Printf.sprintf
"Job completed at %f\nReturn code: %s\n"
(Unix.time ())
(string_of_ret_code result.ret_code)))
; end_job result.id sockaddr >>= Lwt.pause >>= launch_job ]))
else
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" result.pass)
| `JOB submission ->
if pass = submission.pass then (
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive submission: %s"
(Serialization_j.string_of_submission submission))
>>= fun () ->
let empty = CCDeque.is_empty jobs in
for i = 0 to submission.iteration - 1 do
let computation =
{ id= (!jobs_id, i)
; env=
[ Printf.sprintf "OCLUSTER_ARRAY_TASK_ID=%i" i
; Printf.sprintf "OCLUSTER_TASK_ID=%i" !jobs_id ]
; script= submission.script
; args= submission.args
; time= submission.time
; pass
; port
; cpu= submission.cpu
; ram= submission.ram }
in
CCDeque.push_back jobs computation
done ;
incr jobs_id ;
if empty then launch_job () else Lwt.return_unit )
else
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" submission.pass)
| _ ->
(fun oc ->
Lwt_io.write oc
(Printf.sprintf
"Job completed at %f\nReturn code: %s\n"
(Unix.time ())
(string_of_ret_code result.ret_code)))
; end_job result.id sockaddr >>= Lwt.pause
>>= launch_job pass ]))
| `JOB submission ->
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive submission: %s"
(Serialization_j.string_of_submission submission))
>>= fun () ->
let empty = CCDeque.is_empty jobs in
for i = 0 to submission.iteration - 1 do
let computation =
{ id= (!jobs_id, i)
; env=
[ Printf.sprintf "OCLUSTER_ARRAY_TASK_ID=%i" i
; Printf.sprintf "OCLUSTER_TASK_ID=%i" !jobs_id ]
; script= submission.script
; args= submission.args
; time= submission.time
; port
; cpu= submission.cpu
; ram= submission.ram }
in
CCDeque.push_back jobs computation
done ;
incr jobs_id ;
if empty then launch_job pass () else Lwt.return_unit
| _ ->
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Unwanted command"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Receive a unwanted command")
else
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Unwanted command"))
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Receive a unwanted command") )
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" query_pass)
| Result.Error e ->
Lwt_io.write_line oc
(Serialization_j.string_of_answer
......
......@@ -42,13 +42,13 @@ let run_computation (computation : computation) =
( if String.length stderr > max_std then
CCString.drop (String.length stderr - max_std) stderr
else stderr )
; ret_code= process_status_to_ret_code ret_code
; pass= computation.pass })
; ret_code= process_status_to_ret_code ret_code })
let rec send_result sockaddr result =
let rec send_result sockaddr pass result =
if%lwt
Lwt_io.with_connection sockaddr (fun (ic, oc) ->
Lwt_io.write_line oc (Serialization_j.string_of_query (`RESULT result))
Lwt_io.write_line oc
(Serialization_j.string_of_query (pass, `RESULT result))
>>= fun () ->
Lwt_io.flush oc
>>= fun () ->
......@@ -78,9 +78,9 @@ let rec send_result sockaddr result =
with End_of_file ->
Logs_lwt.err (fun m -> m "Error during the read of the answer: EOF")
>>= fun () -> Lwt.return true)
then send_result sockaddr result
then send_result sockaddr pass result
let handle_computation sockaddr computation () =
let handle_computation sockaddr pass computation () =
run_computation computation
>>= fun result ->
let sockaddr =
......@@ -90,7 +90,7 @@ let handle_computation sockaddr computation () =
| s ->
s
in
send_result sockaddr result
send_result sockaddr pass result
<&> Logs_lwt.debug (fun m ->
m "End computation %i,%i" (fst computation.id) (snd computation.id))
......@@ -100,26 +100,27 @@ let server_handler pass sockaddr (ic, oc) =
Lwt_io.read_line ic
>>= fun json ->
( match CCResult.guard (fun () -> Serialization_j.query_of_string json) with
| Result.Ok query -> (
match query with
| `COMPUTATION (computation : computation) ->
if pass = computation.pass then
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive computation: %s"
(Serialization_j.string_of_computation
{computation with env= []; script= "<script>"}))
>|= fun () -> Lwt.async (handle_computation sockaddr computation)
else
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" computation.pass)
| `STAT ->
stat oc <&> Logs_lwt.debug (fun m -> m "Receive a stat command")
| _ ->
Logs_lwt.warn (fun m -> m "Receive a unwanted command") )
| Result.Ok (query_pass, query) ->
if query_pass = pass then
match query with
| `COMPUTATION (computation : computation) ->
Lwt_io.write_line oc (Serialization_j.string_of_answer `Ok)
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.debug (fun m ->
m "Receive computation: %s"
(Serialization_j.string_of_computation
{computation with env= []; script= "<script>"}))
>|= fun () ->
Lwt.async (handle_computation sockaddr pass computation)
| `STAT ->
stat oc <&> Logs_lwt.debug (fun m -> m "Receive a stat command")
| _ ->
Logs_lwt.warn (fun m -> m "Receive a unwanted command")
else
Lwt_io.write_line oc
(Serialization_j.string_of_answer (`Error "Wrong password"))
>>= (fun () -> Lwt_io.flush oc)
<&> Logs_lwt.warn (fun m -> m "Wrong password: %s" query_pass)
| Result.Error e ->
Logs_lwt.err (fun m ->
m "Error during the reception of the computation: %s"
......
......@@ -3,19 +3,20 @@ type node = {addr: string; ~port <ocaml default="4242">: int; cpu: int; ram: int
type master_conf = {pass: string; ~port <ocaml default="4242">: int; nodes: node list}
type node_conf = {pass: string; ~port <ocaml default="4242">: int}
type computation = {id: (int * int); env: string list; script: string; args: string list; time: float option; pass: string; port: int; cpu: int; ram: int}
type computation = {id: (int * int); env: string list; script: string; args: string list; time: float option; port: int; cpu: int; ram: int}
type ret_code = [WEXITED of int | WSIGNALED of int | WSTOPPED of int]
type submission = {script: string; args: string list; time: float option; iteration: int <ocaml default="1">;
cpu: int <ocaml default="1">; ram: int <ocaml default="1024">;
pass: string}
cpu: int <ocaml default="1">; ram: int <ocaml default="1024">;}
type query = [COMPUTATION of computation | STAT | RESULT of result | JOB of submission]
type result =
{id: (int * int); stdout: string; stderr: string; ret_code: ret_code}
type query_data = [COMPUTATION of computation | STAT | RESULT of result | JOB of submission]
type query = (string * query_data)
type answer = [Ok | Error of string]
type stat = [OK]
type result =
{id: (int * int); stdout: string; stderr: string; ret_code: ret_code; pass: string}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment