Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Tools for encrypting segments #86

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions dashlivesim/dashlib/structops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def str_to_uint64(string8):
"8-character string to unsigned int64."
return unpack(">Q", string8)[0]

def uint8_to_str(uint8):
"Unsigned int8 to string."
return pack(">B", uint8)

def uint32_to_str(uint32):
"Unsigned int32 to string."
return pack(">I", uint32)
Expand Down
Empty file added dashlivesim/encrypt/__init__.py
Empty file.
203 changes: 203 additions & 0 deletions dashlivesim/encrypt/encrypt_segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Encrypt segments using Bento4's mp4encrypt."""

import os
from os.path import join, basename, dirname
from shutil import copyfile
import base64
import json
from argparse import ArgumentParser

from dashlivesim.vodanalyzer.dashanalyzer import DashAnalyzer
from dashlivesim.encrypt.fix_media_segments import fix_segment

KEY = '+5f6+rqidg+YaZG/0IyQcA=='
IV = '0101020304050607'
KIDGUUID = '3712a6d0-617e-43ca-b655-3475a2ac9135'

DRM_NAMES = {'edef8ba9-79d6-4ace-a3c8-27dcd51d21ed': 'Widevine',
'9a04f079-9840-4286-ab92-e65be0885f95': 'MSPR 2.0'}

ENCRYPT_TEMPLATE = ("mp4encrypt --method MPEG-CENC --fragments-info %(init)s "
"--key %(track_id)d:%(key_hex)s:%(iv)s "
"--property %(track_id)d:KID:%(kid_hex)s "
"--global-option mpeg-cenc.iv-size-8:true "
"%(infile)s %(outfile)s")

INIT_ENC_TEMPLATE = ("mp4encrypt --method MPEG-CENC "
"--key %(track_id)d:%(key_hex)s:%(iv)s "
"--property %(track_id)d:KID:%(kid_hex)s "
"--global-option mpeg-cenc.iv-size-8:true "
"%(infile)s %(outfile)s")

CP_GEN_TEMPLATE = ('<ContentProtection '
'schemeIdUri="urn:mpeg:dash:mp4protection:2011" '
'value="cenc" '
'cenc:default_KID="%(kid)s" />\n')

CP_DRM_TEMPLATE = ('<ContentProtection '
'value = "%(name)s" '
'schemeIdUri = '
'"urn:uuid:%(system_id)s">\n'
'<cenc:pssh>%(pssh)s</cenc:pssh>\n'
'</ContentProtection>\n')


def generate_mpd_cp_part(key_data):
parts = []
parts.append(CP_GEN_TEMPLATE % key_data)
for drm in key_data['drm_data']:
drm['name'] = DRM_NAMES[drm['system_id']]
parts.append(CP_DRM_TEMPLATE % drm)
return "".join(parts)


def read_drm_info(infile):
"Read drm_info in JSON format produced by extract_drm_info"
with open(infile, 'rb') as ifh:
json_data = json.load(ifh)
return json_data


def get_kid_and_key(key_data):
key = key_data['cek']
kid = key_data['kid']
kid_hex = kid.replace('-', '')
return kid_hex, key


def print_mpd_data(key_data, drm_nr, seg_nr):
print("#### MPD DATA %d segment=%d #####\n%s\n" %
(drm_nr, seg_nr, generate_mpd_cp_part(key_data)))


def encrypt(in_manifest, out_dir, drm_data, rotation_interval, iv):
"Encrypt a DASH asset by modifying the manifest and encrypting the segments."
dash_analyzer = DashAnalyzer(in_manifest)
dash_analyzer.initMedia()

drm_nr = 0

adaptation_sets = dash_analyzer.mpdProcessor.adaptation_sets
for aset in adaptation_sets:
content_type = aset.content_type
if content_type == 'video':
reps_data = dash_analyzer.as_data[content_type]
for rep in aset.representations:
for r in reps_data['reps']:
if r['representation'] == rep:
rep_data = r
break
else:
raise ValueError("No representation found")
init_path = rep_data['relInitPath']
out_seg_dir = join(out_dir, dirname(init_path))
drm_nr = 1
key_data = drm_data[drm_nr]
kid_hex, key = get_kid_and_key(key_data)
if not os.path.exists(out_seg_dir):
os.makedirs(out_seg_dir)
in_init_path = rep_data['absInitPath']
out_init_path = join(out_seg_dir, basename(init_path))
encrypt_segment("",
in_init_path,
out_init_path,
track_id=2,
key=key,
kid_hex=kid_hex,
iv=iv,
template=INIT_ENC_TEMPLATE)

for i in range(rep_data['firstNumber'],
rep_data['lastNumber'] + 1):
#if i > 20: # TODO remove this
# break
if i == rep_data['firstNumber']:
print_mpd_data(key_data, drm_nr, i)
if rotation_interval > 0:
rel_nr = i - rep_data['firstNumber']
old_drm_nr = drm_nr
new_drm_nr = rel_nr // rotation_interval
if new_drm_nr != old_drm_nr:
drm_nr = new_drm_nr
key_data = drm_data[drm_nr]
kid_hex, key = get_kid_and_key(key_data)
print_mpd_data(key_data, drm_nr, i)
in_seg_rel_path = rep_data['relMediaPath'] % i
in_seg_abs_path = rep_data['absMediaPath'] % i
out_seg_abs_path = join(out_dir, in_seg_rel_path)
out_seg_tmp_path = join(out_dir, 'tmp_out.m4s')
encrypt_segment(rep_data['absInitPath'],
in_seg_abs_path,
out_seg_tmp_path,
track_id=2,
key=key,
kid_hex=kid_hex,
iv=iv)
fix_segment(out_seg_tmp_path, out_seg_abs_path, kid_hex, 8)
os.unlink(out_seg_tmp_path)

else:
reps_data = dash_analyzer.as_data[content_type]
for rep in aset.representations:
for r in reps_data['reps']:
if r['representation'] == rep:
rep_data = r
break
else:
raise ValueError("No representation found")
init_path = rep_data['relInitPath']
out_seg_dir = join(out_dir, dirname(init_path))
if not os.path.exists(out_seg_dir):
os.makedirs(out_seg_dir)
out_init_name = join(out_seg_dir, basename(init_path))
copyfile(rep_data['absInitPath'], out_init_name)
for i in range(rep_data['firstNumber'],
rep_data['lastNumber']):
in_seg_rel_path = rep_data['relMediaPath'] % i
in_seg_abs_path = rep_data['absMediaPath'] % i
out_seg_abs_path = join(out_dir, in_seg_rel_path)
copyfile(in_seg_abs_path, out_seg_abs_path)


def encrypt_segment(init_seg, in_seg, out_seg, track_id, key, kid_hex, iv,
template = ENCRYPT_TEMPLATE):
key_hex = base64tohex(key)
cmd = template % \
{'init': init_seg,
'track_id': track_id,
'key_hex': key_hex,
'iv': iv,
'kid_hex': kid_hex,
'infile': in_seg,
'outfile': out_seg
}
os.system(cmd)


def base64tohex(b64str):
"Translate from base64 to hex."
decoded_str = base64.b64decode(b64str)
return "".join(["%02x" % ord(c) for c in decoded_str])


def main():
parser = ArgumentParser()
parser.add_argument('in_manifest')
parser.add_argument('out_dir')
parser.add_argument('drm_file')
parser.add_argument('--rot', type=int, default=0, help="Rotate every n "
"segments")

args = parser.parse_args()
drm_data = read_drm_info(args.drm_file)
encrypt(args.in_manifest, args.out_dir, drm_data, args.rot, IV)

if __name__ == "__main__":
main()
#drm_data = read_drm_info(
# '/Users/tobbe/proj/github/DashIF/dash-live-source'
# '-simulator/dashlivesim/encrypt/DrmData.json')
#in_manifest = '/Users/tobbe/Sites/dash/vod/testpic_2s_2min/Manifest.mpd'
#out_dir = '/Users/tobbe/Sites/dash/vod/testpic_2s_2min_enc'
#encrypt(in_manifest, out_dir, drm_data, IV)
49 changes: 49 additions & 0 deletions dashlivesim/encrypt/extract_drm_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"Extract DRM data using the CPIX Python3 module."

import base64
import sys
import re
from collections import OrderedDict
import json

import cpix

PSSH_PATTERN = re.compile(b'<pssh[^>]+>([^<]+)</pssh>')

def extract_data(file_name):
with open(file_name, 'rb') as ifh:
xml = ifh.read()
cp = cpix.parse(xml)
cp_data = []
for key in cp.content_keys:
cek = key.cek
kid = str(key.kid)
cp_data.append({'kid': kid,
'cek': cek,
'drm_data': []
})
for drm_system in cp.drm_systems:
kid = str(drm_system.kid)
system_id = str(drm_system.system_id)
cp_data_b64 = drm_system.content_protection_data
cp_data_parts = base64.b64decode(cp_data_b64).split(b'\r\n')
pssh_data = cp_data_parts[0]
mobj = PSSH_PATTERN.match(pssh_data)
if mobj:
pssh = mobj.groups(1)[0].decode('utf-8')
else:
raise ValueError("Did not find pssh data")
for cpd in cp_data:
if cpd['kid'] == kid:
cpd['drm_data'].append({'system_id': system_id,
'pssh': pssh})
break
else:
raise ValueError("Did not find {kid}")

json_out = json.dumps(cp_data)
print(json_out)

if __name__ == "__main__":
extract_data(sys.argv[1])
116 changes: 116 additions & 0 deletions dashlivesim/encrypt/fix_media_segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Fix media segments by adding sgpd and sbgp boxes."""

from argparse import ArgumentParser

from dashlivesim.dashlib.mp4filter import MP4Filter
from dashlivesim.dashlib.structops import (str_to_uint32, uint32_to_str,
uint8_to_str, sint32_to_str,
str_to_sint32)

class MediaSegmentFilterError(Exception):
"Error in MediaSegmentFilter."


class FixEncryptedSegment(MP4Filter):
"""Add sgpd and sbgp boxes to an encrypted segment."""

def __init__(self, file_name, key_id, iv_size=8):
MP4Filter.__init__(self, file_name)
self.key_id = key_id
self.iv_size = iv_size
self.top_level_boxes_to_parse = ["moof"]
self.composite_boxes_to_parse = ['moof', 'traf']
self.senc_sample_count = None
self.size_change = 72 # sgpd + sbgp = 44 + 28

def process_trun(self, data):
"Get total duration from trun. Fix offset if self.size_change is non-zero."
flags = str_to_uint32(data[8:12]) & 0xffffff
sample_count = str_to_uint32(data[12:16])
pos = 16
data_offset_present = False
if flags & 0x1: # Data offset present
data_offset_present = True
pos += 4
if flags & 0x4:
pos += 4 # First sample flags present
sample_duration_present = flags & 0x100
sample_size_present = flags & 0x200
sample_flags_present = flags & 0x400
sample_comp_time_present = flags & 0x800
duration = 0
for _ in range(sample_count):
if sample_duration_present:
duration += str_to_uint32(data[pos:pos + 4])
pos += 4
else:
duration += self.default_sample_duration
if sample_size_present:
pos += 4
if sample_flags_present:
pos += 4
if sample_comp_time_present:
pos += 4
self.duration = duration

# Modify data_offset
output = data[:16]
if data_offset_present and self.size_change > 0:
offset = str_to_sint32(data[16:20])
offset += self.size_change
output += sint32_to_str(offset)
else:
output += data[16:20]
output += data[20:]
return output

def process_senc(self, data):
"Get the number of entries."
version_and_flags = str_to_uint32(data[8:12])
sample_count = str_to_uint32(data[12:16])
self.senc_sample_count = sample_count
# Skip parsing the rest
return data + self.generate_sgpd() + self.generate_sbgp()

def generate_sgpd(self):
"Generate an appropriate sgpd box."
output = uint32_to_str(44) + 'sgpd' + '\x01\x00\x00\x00' + 'seig'
output += uint32_to_str(20) # defaultLength
output += '\x00\x00\x00\x01' # nr groupEntries
output += '\x00\x00\x01' + uint8_to_str(self.iv_size)
output += self.key_id
assert len(output) == 44
return output

def generate_sbgp(self):
"Generate an appropriate sbgp box."
output = uint32_to_str(28) + 'sbgp' + '\x00\x00\x00\x00' + 'seig'
output += '\x00\x00\x00\x01' # nr entries
output += uint32_to_str(self.senc_sample_count)
output += '\x00\x01\x00\x01' # first local groupDescriptionIndex
assert len(output) == 28
return output


def fix_segment(infile, outfile, kid_hex, iv_size):
kid = kid_hex.decode('hex')
fs = FixEncryptedSegment(infile, kid, iv_size)
outdata = fs.filter()
with open(outfile, 'wb') as ofh:
ofh.write(outdata)


def main():
parser = ArgumentParser()
parser.add_argument('infile')
parser.add_argument('outfile')
parser.add_argument('kid_hex')
parser.add_argument('iv_size', type=int)

args = parser.parse_args()
print("'%s' %d" % (args.kid_hex, len(args.kid_hex)))
fix_segment(args.infile, args.outfile, args.kid_hex, args.iv_size)


if __name__ == "__main__":
main()